Commit 321f459 · Yi Xie committed
Parent(s): d7ffaa3

Add MangaScaleV3 on ESRGAN+ arch

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full changeset.
- .gitignore +26 -0
- converter.py +4 -1
- esrgan_plus/LICENSE +201 -0
- esrgan_plus/README.md +48 -0
- esrgan_plus/codes/auto_test.py +32 -0
- esrgan_plus/codes/data/LRHR_dataset.py +128 -0
- esrgan_plus/codes/data/LRHR_seg_bg_dataset.py +149 -0
- esrgan_plus/codes/data/LR_dataset.py +40 -0
- esrgan_plus/codes/data/__init__.py +37 -0
- esrgan_plus/codes/data/util.py +434 -0
- esrgan_plus/codes/models/SFTGAN_ACD_model.py +261 -0
- esrgan_plus/codes/models/SRGAN_model.py +240 -0
- esrgan_plus/codes/models/SRRaGAN_model.py +251 -0
- esrgan_plus/codes/models/SR_model.py +151 -0
- esrgan_plus/codes/models/__init__.py +20 -0
- esrgan_plus/codes/models/__pycache__/__init__.cpython-310.pyc +0 -0
- esrgan_plus/codes/models/base_model.py +85 -0
- esrgan_plus/codes/models/modules/__pycache__/architecture.cpython-310.pyc +0 -0
- esrgan_plus/codes/models/modules/__pycache__/block.cpython-310.pyc +0 -0
- esrgan_plus/codes/models/modules/__pycache__/spectral_norm.cpython-310.pyc +0 -0
- esrgan_plus/codes/models/modules/architecture.py +394 -0
- esrgan_plus/codes/models/modules/block.py +322 -0
- esrgan_plus/codes/models/modules/loss.py +60 -0
- esrgan_plus/codes/models/modules/seg_arch.py +70 -0
- esrgan_plus/codes/models/modules/sft_arch.py +226 -0
- esrgan_plus/codes/models/modules/spectral_norm.py +149 -0
- esrgan_plus/codes/models/networks.py +155 -0
- esrgan_plus/codes/options/options.py +120 -0
- esrgan_plus/codes/options/test/test_ESRGANplus.json +40 -0
- esrgan_plus/codes/options/test/test_SRGAN.json +37 -0
- esrgan_plus/codes/options/test/test_SRResNet.json +40 -0
- esrgan_plus/codes/options/test/test_sr.json +40 -0
- esrgan_plus/codes/options/train/train_ESRGANplus.json +83 -0
- esrgan_plus/codes/options/train/train_SRGAN.json +87 -0
- esrgan_plus/codes/options/train/train_SRResNet.json +66 -0
- esrgan_plus/codes/options/train/train_sftgan.json +76 -0
- esrgan_plus/codes/options/train/train_sr.json +66 -0
- esrgan_plus/codes/scripts/README.md +8 -0
- esrgan_plus/codes/scripts/back_projection/backprojection.m +20 -0
- esrgan_plus/codes/scripts/back_projection/main_bp.m +22 -0
- esrgan_plus/codes/scripts/back_projection/main_reverse_filter.m +25 -0
- esrgan_plus/codes/scripts/color2gray.py +63 -0
- esrgan_plus/codes/scripts/create_lmdb.py +66 -0
- esrgan_plus/codes/scripts/extract_enlarge_patches.py +64 -0
- esrgan_plus/codes/scripts/extract_subimgs_single.py +88 -0
- esrgan_plus/codes/scripts/generate_mod_LR_bic.m +82 -0
- esrgan_plus/codes/scripts/generate_mod_LR_bic.py +74 -0
- esrgan_plus/codes/scripts/make_gif_video.py +106 -0
- esrgan_plus/codes/scripts/net_interp.py +20 -0
- esrgan_plus/codes/scripts/rename.py +25 -0
.gitignore
ADDED
@@ -0,0 +1,26 @@
# General
.DS_Store
.AppleDouble
.LSOverride

# Icon must end with two \r
Icon

# Thumbnails
._*

# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent

# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk
converter.py
CHANGED
@@ -28,7 +28,7 @@ parser = argparse.ArgumentParser(
 )
 parser.add_argument('filename')
 required_args = parser.add_argument_group('required')
-required_args.add_argument('--type', choices=['esrgan_old', 'esrgan_old_lite', 'real_esrgan', 'real_esrgan_compact'], required=True, help='Type of the model')
+required_args.add_argument('--type', choices=['esrgan_old', 'esrgan_old_lite', 'real_esrgan', 'real_esrgan_compact', 'esrgan_plus'], required=True, help='Type of the model')
 required_args.add_argument('--name', type=str, required=True, help='Name of the model')
 required_args.add_argument('--scale', type=int, required=True, help='Scale factor of the model')
 required_args.add_argument('--out-dir', type=str, required=True, help='Output directory')
@@ -118,6 +118,9 @@ elif args.type == 'real_esrgan':
 elif args.type == 'real_esrgan_compact':
     from basicsr.archs.srvgg_arch import SRVGGNetCompact
     torch_model = SRVGGNetCompact(num_in_ch=channels, num_out_ch=channels, num_feat=num_features, num_conv=num_convs, upscale=args.scale, act_type='prelu')
+elif args.type == 'esrgan_plus':
+    from esrgan_plus.codes.models.modules.architecture import RRDBNet
+    torch_model = RRDBNet(in_nc=channels, out_nc=channels, nf=num_features, nb=num_blocks, gc=32, upscale=args.scale)
 else:
     logger.fatal('Unknown model type: %s', args.type)
     sys.exit(-1)
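For reference, a conversion run with the new model type might look like the following (the checkpoint filename and output directory are hypothetical, and the scale factor is an assumption; the flags are the ones defined by the parser above):

    python converter.py MangaScaleV3.pth --type esrgan_plus --name MangaScaleV3 --scale 4 --out-dir out/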
esrgan_plus/LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
esrgan_plus/README.md
ADDED
@@ -0,0 +1,48 @@
# ESRGAN+ nESRGAN+ Tarsier
## ICASSP 2020 - ESRGAN+ : Further Improving Enhanced Super-Resolution Generative Adversarial Network
### [Paper arXiv](https://arxiv.org/abs/2001.08073)
### [Paper IEEE Xplore](https://ieeexplore.ieee.org/document/9054071)
## ICPR 2020 - Tarsier: Evolving Noise Injection in Super-Resolution GANs
### [Paper arXiv](https://arxiv.org/abs/2009.12177)

<p align="center">
  <img height="250" src="./figures/noise_per_residual_dense_block.PNG">
</p>

<p align="center">
  <img src="./figures/qualitative_result.PNG">
</p>

### Dependencies

- Python 3 (recommended: [Anaconda](https://www.anaconda.com/download/#linux))
- [PyTorch >= 1.0.0](https://pytorch.org/)
- NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads)
- Python packages: `pip install numpy opencv-python lmdb tensorboardX`

## How to test
1. Place your low-resolution images in the `test_image/LR` folder.
2. Download pretrained models from [Google Drive](https://drive.google.com/drive/folders/1lNky9afqEP-qdxrAwDFPJ1g0ui4x7Sin?usp=sharing) and place them in `test_image/pretrained_models`.
3. Run the command: `python test_image/test.py test_image/pretrained_models/nESRGANplus.pth` (or any other model).
4. The results are in the `test_image/results` folder.

## How to train
1. Prepare the datasets, which can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1pRmhEmmY-tPF7uH8DuVthfHoApZWJ1QU).
2. Prepare the PSNR-oriented pretrained model (all pretrained models can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1lNky9afqEP-qdxrAwDFPJ1g0ui4x7Sin?usp=sharing)).
3. Modify the configuration file `codes/options/train/train_ESRGANplus.json`.
4. Run the command `python train.py -opt codes/options/train/train_ESRGANplus.json`.

## Acknowledgement
- This code is based on [BasicSR](https://github.com/xinntao/BasicSR).

## Citation

    @INPROCEEDINGS{9054071,
      author={N. C. {Rakotonirina} and A. {Rasoanaivo}},
      booktitle={ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
      title={ESRGAN+ : Further Improving Enhanced Super-Resolution Generative Adversarial Network},
      year={2020},
      volume={},
      number={},
      pages={3637-3641},}
esrgan_plus/codes/auto_test.py
ADDED
@@ -0,0 +1,32 @@
'''auto test several models.'''

import json
import os

test_json_path = 'options/test/test_esrgan_auto.json'


def modify_json(json_path, model_name, iteration):
    with open(json_path, 'r+') as json_file:
        config = json.load(json_file)

        config['name'] = model_name
        config['datasets']['test_1']['name'] = 'pirm_test_{:d}k'.format(iteration)
        # config['datasets']['test_1']['dataroot_LR'] = \
        #     '/home/carraz/datasets/PIRM/PIRM_Test_set/LR'
        config['path']['pretrain_model_G'] = \
            '../experiments/{:s}/models/{:d}_G.pth'.format(model_name, iteration * 1000)
        json_file.seek(0)  # rewind
        json.dump(config, json_file)
        json_file.truncate()  # in case the new data is smaller than the previous


model_iter_dict = {}
model_iter_dict['100_ESRGAN_SRResNet_pristine_pixel10_minc'] = [80, 85, 90, 95]

for model_name, iter_list in model_iter_dict.items():
    for iteration in iter_list:
        modify_json(test_json_path, model_name, iteration)
        # run test scripts
        print('\n\nTesting {:s} {:d}k...'.format(model_name, iteration))
        os.system('python test.py -opt ' + test_json_path)
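To make the rewrite concrete, here is a sketch of what one pass of the loop does; the surrounding structure of the JSON config is assumed, and only the three keys touched by modify_json are shown:

    modify_json(test_json_path, '100_ESRGAN_SRResNet_pristine_pixel10_minc', 80)
    # test_esrgan_auto.json afterwards contains, among its other keys:
    #   "name": "100_ESRGAN_SRResNet_pristine_pixel10_minc"
    #   "datasets" -> "test_1" -> "name": "pirm_test_80k"
    #   "path" -> "pretrain_model_G":
    #       "../experiments/100_ESRGAN_SRResNet_pristine_pixel10_minc/models/80000_G.pth"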
esrgan_plus/codes/data/LRHR_dataset.py
ADDED
@@ -0,0 +1,128 @@
import os.path
import random
import numpy as np
import cv2
import torch
import torch.utils.data as data
import data.util as util


class LRHRDataset(data.Dataset):
    '''
    Read LR and HR image pairs.
    If only HR images are provided, generate LR images on-the-fly.
    Pairing relies on the 'sorted' order of file names, so check the naming convention.
    '''

    def __init__(self, opt):
        super(LRHRDataset, self).__init__()
        self.opt = opt
        self.paths_LR = None
        self.paths_HR = None
        self.LR_env = None  # environment for lmdb
        self.HR_env = None

        # read image list from subset list txt
        if opt['subset_file'] is not None and opt['phase'] == 'train':
            with open(opt['subset_file']) as f:
                self.paths_HR = sorted([os.path.join(opt['dataroot_HR'], line.rstrip('\n')) \
                    for line in f])
            if opt['dataroot_LR'] is not None:
                raise NotImplementedError('Now subset only supports generating LR on-the-fly.')
        else:  # read image list from lmdb or image files
            self.HR_env, self.paths_HR = util.get_image_paths(opt['data_type'], opt['dataroot_HR'])
            self.LR_env, self.paths_LR = util.get_image_paths(opt['data_type'], opt['dataroot_LR'])

        assert self.paths_HR, 'Error: HR path is empty.'
        if self.paths_LR and self.paths_HR:
            assert len(self.paths_LR) == len(self.paths_HR), \
                'HR and LR datasets have different number of images - {}, {}.'.format(\
                len(self.paths_LR), len(self.paths_HR))

        self.random_scale_list = [1]

    def __getitem__(self, index):
        HR_path, LR_path = None, None
        scale = self.opt['scale']
        HR_size = self.opt['HR_size']

        # get HR image
        HR_path = self.paths_HR[index]
        img_HR = util.read_img(self.HR_env, HR_path)
        # modcrop in the validation / test phase
        if self.opt['phase'] != 'train':
            img_HR = util.modcrop(img_HR, scale)
        # change color space if necessary
        if self.opt['color']:
            img_HR = util.channel_convert(img_HR.shape[2], self.opt['color'], [img_HR])[0]

        # get LR image
        if self.paths_LR:
            LR_path = self.paths_LR[index]
            img_LR = util.read_img(self.LR_env, LR_path)
        else:  # down-sampling on-the-fly
            # randomly scale during training
            if self.opt['phase'] == 'train':
                random_scale = random.choice(self.random_scale_list)
                H_s, W_s, _ = img_HR.shape

                def _mod(n, random_scale, scale, thres):
                    rlt = int(n * random_scale)
                    rlt = (rlt // scale) * scale
                    return thres if rlt < thres else rlt

                H_s = _mod(H_s, random_scale, scale, HR_size)
                W_s = _mod(W_s, random_scale, scale, HR_size)
                img_HR = cv2.resize(np.copy(img_HR), (W_s, H_s), interpolation=cv2.INTER_LINEAR)
                # force to 3 channels
                if img_HR.ndim == 2:
                    img_HR = cv2.cvtColor(img_HR, cv2.COLOR_GRAY2BGR)

            H, W, _ = img_HR.shape
            # using matlab imresize
            img_LR = util.imresize_np(img_HR, 1 / scale, True)
            if img_LR.ndim == 2:
                img_LR = np.expand_dims(img_LR, axis=2)

        if self.opt['phase'] == 'train':
            # if the image size is too small
            H, W, _ = img_HR.shape
            if H < HR_size or W < HR_size:
                img_HR = cv2.resize(
                    np.copy(img_HR), (HR_size, HR_size), interpolation=cv2.INTER_LINEAR)
                # using matlab imresize
                img_LR = util.imresize_np(img_HR, 1 / scale, True)
                if img_LR.ndim == 2:
                    img_LR = np.expand_dims(img_LR, axis=2)

            H, W, C = img_LR.shape
            LR_size = HR_size // scale

            # randomly crop
            rnd_h = random.randint(0, max(0, H - LR_size))
            rnd_w = random.randint(0, max(0, W - LR_size))
            img_LR = img_LR[rnd_h:rnd_h + LR_size, rnd_w:rnd_w + LR_size, :]
            rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
            img_HR = img_HR[rnd_h_HR:rnd_h_HR + HR_size, rnd_w_HR:rnd_w_HR + HR_size, :]

            # augmentation - flip, rotate
            img_LR, img_HR = util.augment([img_LR, img_HR], self.opt['use_flip'], \
                self.opt['use_rot'])

            # change color space if necessary
            if self.opt['color']:
                img_LR = util.channel_convert(C, self.opt['color'], [img_LR])[0]  # TODO: no color conversion for LR during val

        # BGR to RGB, HWC to CHW, numpy to tensor
        if img_HR.shape[2] == 3:
            img_HR = img_HR[:, :, [2, 1, 0]]
            img_LR = img_LR[:, :, [2, 1, 0]]
        img_HR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_HR, (2, 0, 1)))).float()
        img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float()

        if LR_path is None:
            LR_path = HR_path
        return {'LR': img_LR, 'HR': img_HR, 'LR_path': LR_path, 'HR_path': HR_path}

    def __len__(self):
        return len(self.paths_HR)
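A minimal usage sketch of the dataset above (the HR directory is a hypothetical path; the option keys are the ones read by LRHRDataset, and the snippet assumes it is run from esrgan_plus/codes/ so the data package resolves). With 'dataroot_LR' left as None, the LR patch is generated on-the-fly with util.imresize_np:

    from data.LRHR_dataset import LRHRDataset

    opt = {
        'phase': 'train',
        'data_type': 'img',
        'subset_file': None,
        'dataroot_HR': '/data/DIV2K/HR',  # hypothetical path
        'dataroot_LR': None,              # None -> bicubic downsample on-the-fly
        'scale': 4,
        'HR_size': 128,
        'color': None,
        'use_flip': True,
        'use_rot': True,
    }
    dataset = LRHRDataset(opt)
    sample = dataset[0]
    print(sample['LR'].shape, sample['HR'].shape)  # torch.Size([3, 32, 32]) torch.Size([3, 128, 128])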
esrgan_plus/codes/data/LRHR_seg_bg_dataset.py
ADDED
@@ -0,0 +1,149 @@
import os.path
import random
import numpy as np
import cv2
import torch
import torch.utils.data as data
import data.util as util


class LRHRSeg_BG_Dataset(data.Dataset):
    '''
    Read HR images and segmentation probability maps; generate LR images and the
    category label for SFTGAN. Also samples general scenes for the background class.
    LR images need to be generated on-the-fly.
    '''

    def __init__(self, opt):
        super(LRHRSeg_BG_Dataset, self).__init__()
        self.opt = opt
        self.paths_LR = None
        self.paths_HR = None
        self.paths_HR_bg = None  # HR images for background scenes
        self.LR_env = None  # environment for lmdb
        self.HR_env = None
        self.HR_env_bg = None

        # read image list from lmdb or image files
        self.HR_env, self.paths_HR = util.get_image_paths(opt['data_type'], opt['dataroot_HR'])
        self.LR_env, self.paths_LR = util.get_image_paths(opt['data_type'], opt['dataroot_LR'])
        self.HR_env_bg, self.paths_HR_bg = util.get_image_paths(opt['data_type'], \
            opt['dataroot_HR_bg'])

        assert self.paths_HR, 'Error: HR path is empty.'
        if self.paths_LR and self.paths_HR:
            assert len(self.paths_LR) == len(self.paths_HR), \
                'HR and LR datasets have different number of images - {}, {}.'.format(\
                len(self.paths_LR), len(self.paths_HR))

        self.random_scale_list = [1, 0.9, 0.8, 0.7, 0.6, 0.5]
        self.ratio = 10  # 10 OST data samples to 1 DIV2K general data sample (background)

    def __getitem__(self, index):
        HR_path, LR_path = None, None
        scale = self.opt['scale']
        HR_size = self.opt['HR_size']

        # get HR image
        if self.opt['phase'] == 'train' and \
                random.choice(list(range(self.ratio))) == 0:  # read background images
            bg_index = random.randint(0, len(self.paths_HR_bg) - 1)
            HR_path = self.paths_HR_bg[bg_index]
            img_HR = util.read_img(self.HR_env_bg, HR_path)
            seg = torch.FloatTensor(8, img_HR.shape[0], img_HR.shape[1]).fill_(0)
            seg[0, :, :] = 1  # background
        else:
            HR_path = self.paths_HR[index]
            img_HR = util.read_img(self.HR_env, HR_path)
            seg = torch.load(HR_path.replace('/img/', '/bicseg/').replace('.png', '.pth'))
            # reads segmentation files; change this to match your settings.

        # modcrop in the validation / test phase
        if self.opt['phase'] != 'train':
            img_HR = util.modcrop(img_HR, 8)

        seg = np.transpose(seg.numpy(), (1, 2, 0))

        # get LR image
        if self.paths_LR:
            LR_path = self.paths_LR[index]
            img_LR = util.read_img(self.LR_env, LR_path)
        else:  # down-sampling on-the-fly
            # randomly scale during training
            if self.opt['phase'] == 'train':
                random_scale = random.choice(self.random_scale_list)
                H_s, W_s, _ = seg.shape

                def _mod(n, random_scale, scale, thres):
                    rlt = int(n * random_scale)
                    rlt = (rlt // scale) * scale
                    return thres if rlt < thres else rlt

                H_s = _mod(H_s, random_scale, scale, HR_size)
                W_s = _mod(W_s, random_scale, scale, HR_size)
                img_HR = cv2.resize(np.copy(img_HR), (W_s, H_s), interpolation=cv2.INTER_LINEAR)
                seg = cv2.resize(np.copy(seg), (W_s, H_s), interpolation=cv2.INTER_NEAREST)

            H, W, _ = img_HR.shape
            # using matlab imresize
            img_LR = util.imresize_np(img_HR, 1 / scale, True)
            if img_LR.ndim == 2:
                img_LR = np.expand_dims(img_LR, axis=2)

        H, W, C = img_LR.shape
        if self.opt['phase'] == 'train':
            LR_size = HR_size // scale

            # randomly crop
            rnd_h = random.randint(0, max(0, H - LR_size))
            rnd_w = random.randint(0, max(0, W - LR_size))
            img_LR = img_LR[rnd_h:rnd_h + LR_size, rnd_w:rnd_w + LR_size, :]
            rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
            img_HR = img_HR[rnd_h_HR:rnd_h_HR + HR_size, rnd_w_HR:rnd_w_HR + HR_size, :]
            seg = seg[rnd_h_HR:rnd_h_HR + HR_size, rnd_w_HR:rnd_w_HR + HR_size, :]

            # augmentation - flip, rotate
            img_LR, img_HR, seg = util.augment([img_LR, img_HR, seg], self.opt['use_flip'],
                self.opt['use_rot'])

            # category
            if 'building' in HR_path:
                category = 1
            elif 'plant' in HR_path:
                category = 2
            elif 'mountain' in HR_path:
                category = 3
            elif 'water' in HR_path:
                category = 4
            elif 'sky' in HR_path:
                category = 5
            elif 'grass' in HR_path:
                category = 6
            elif 'animal' in HR_path:
                category = 7
            else:
                category = 0  # background
        else:
            category = -1  # unused during validation

        # BGR to RGB, HWC to CHW, numpy to tensor
        if img_HR.shape[2] == 3:
            img_HR = img_HR[:, :, [2, 1, 0]]
            img_LR = img_LR[:, :, [2, 1, 0]]
        img_HR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_HR, (2, 0, 1)))).float()
        img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float()
        seg = torch.from_numpy(np.ascontiguousarray(np.transpose(seg, (2, 0, 1)))).float()

        if LR_path is None:
            LR_path = HR_path
        return {
            'LR': img_LR,
            'HR': img_HR,
            'seg': seg,
            'category': category,
            'LR_path': LR_path,
            'HR_path': HR_path
        }

    def __len__(self):
        return len(self.paths_HR)
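A short sketch of the path convention the class above implies (the directory layout is hypothetical; only the replace-based lookup and the substring-to-category rule come from the code):

    hr_path = '/data/OST/train/img/building_001.png'  # hypothetical
    seg_path = hr_path.replace('/img/', '/bicseg/').replace('.png', '.pth')
    print(seg_path)  # /data/OST/train/bicseg/building_001.pth (an 8 x H x W probability map)
    # The 'building' substring in hr_path also selects category 1 during training.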
esrgan_plus/codes/data/LR_dataset.py
ADDED
@@ -0,0 +1,40 @@
import numpy as np
import torch
import torch.utils.data as data
import data.util as util


class LRDataset(data.Dataset):
    '''Read LR images only, in the test phase.'''

    def __init__(self, opt):
        super(LRDataset, self).__init__()
        self.opt = opt
        self.paths_LR = None
        self.LR_env = None  # environment for lmdb

        # read image list from lmdb or image files
        self.LR_env, self.paths_LR = util.get_image_paths(opt['data_type'], opt['dataroot_LR'])
        assert self.paths_LR, 'Error: LR paths are empty.'

    def __getitem__(self, index):
        LR_path = None

        # get LR image
        LR_path = self.paths_LR[index]
        img_LR = util.read_img(self.LR_env, LR_path)
        H, W, C = img_LR.shape

        # change color space if necessary
        if self.opt['color']:
            img_LR = util.channel_convert(C, self.opt['color'], [img_LR])[0]

        # BGR to RGB, HWC to CHW, numpy to tensor
        if img_LR.shape[2] == 3:
            img_LR = img_LR[:, :, [2, 1, 0]]
        img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float()

        return {'LR': img_LR, 'LR_path': LR_path}

    def __len__(self):
        return len(self.paths_LR)
esrgan_plus/codes/data/__init__.py
ADDED
@@ -0,0 +1,37 @@
'''create dataset and dataloader'''
import logging
import torch.utils.data


def create_dataloader(dataset, dataset_opt):
    '''create dataloader'''
    phase = dataset_opt['phase']
    if phase == 'train':
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=dataset_opt['batch_size'],
            shuffle=dataset_opt['use_shuffle'],
            num_workers=dataset_opt['n_workers'],
            drop_last=True,
            pin_memory=True)
    else:
        return torch.utils.data.DataLoader(
            dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True)


def create_dataset(dataset_opt):
    '''create dataset'''
    mode = dataset_opt['mode']
    if mode == 'LR':
        from data.LR_dataset import LRDataset as D
    elif mode == 'LRHR':
        from data.LRHR_dataset import LRHRDataset as D
    elif mode == 'LRHRseg_bg':
        from data.LRHR_seg_bg_dataset import LRHRSeg_BG_Dataset as D
    else:
        raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
    dataset = D(dataset_opt)
    logger = logging.getLogger('base')
    logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__,
                                                           dataset_opt['name']))
    return dataset
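A minimal sketch of the factory wiring above (the path and option values are hypothetical; the keys are the ones read by create_dataset and create_dataloader, and the snippet assumes it is run from esrgan_plus/codes/ so the data package resolves):

    from data import create_dataset, create_dataloader

    dataset_opt = {
        'name': 'demo_LR_set',
        'mode': 'LR',                    # dispatches to data.LR_dataset.LRDataset
        'phase': 'test',
        'data_type': 'img',
        'dataroot_LR': '/data/test/LR',  # hypothetical path
        'color': None,
    }
    dataset = create_dataset(dataset_opt)
    loader = create_dataloader(dataset, dataset_opt)  # non-train phase: batch_size=1
    for batch in loader:
        print(batch['LR_path'][0], batch['LR'].shape)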
esrgan_plus/codes/data/util.py
ADDED
@@ -0,0 +1,434 @@
import os
import math
import pickle
import random
import numpy as np
import lmdb
import torch
import cv2
import logging

IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']

####################
# Files & IO
####################


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def _get_paths_from_images(path):
    assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
    images = []
    for dirpath, _, fnames in sorted(os.walk(path)):
        for fname in sorted(fnames):
            if is_image_file(fname):
                img_path = os.path.join(dirpath, fname)
                images.append(img_path)
    assert images, '{:s} has no valid image file'.format(path)
    return images


def _get_paths_from_lmdb(dataroot):
    env = lmdb.open(dataroot, readonly=True, lock=False, readahead=False, meminit=False)
    keys_cache_file = os.path.join(dataroot, '_keys_cache.p')
    logger = logging.getLogger('base')
    if os.path.isfile(keys_cache_file):
        logger.info('Read lmdb keys from cache: {}'.format(keys_cache_file))
        keys = pickle.load(open(keys_cache_file, "rb"))
    else:
        with env.begin(write=False) as txn:
            logger.info('Creating lmdb keys cache: {}'.format(keys_cache_file))
            keys = [key.decode('ascii') for key, _ in txn.cursor()]
        pickle.dump(keys, open(keys_cache_file, 'wb'))
    paths = sorted([key for key in keys if not key.endswith('.meta')])
    return env, paths


def get_image_paths(data_type, dataroot):
    env, paths = None, None
    if dataroot is not None:
        if data_type == 'lmdb':
            env, paths = _get_paths_from_lmdb(dataroot)
        elif data_type == 'img':
            paths = sorted(_get_paths_from_images(dataroot))
        else:
            raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type))
    return env, paths


def _read_lmdb_img(env, path):
    with env.begin(write=False) as txn:
        buf = txn.get(path.encode('ascii'))
        buf_meta = txn.get((path + '.meta').encode('ascii')).decode('ascii')
    img_flat = np.frombuffer(buf, dtype=np.uint8)
    H, W, C = [int(s) for s in buf_meta.split(',')]
    img = img_flat.reshape(H, W, C)
    return img


def read_img(env, path):
    # read image by cv2 or from lmdb
    # return: Numpy float32, HWC, BGR, [0,1]
    if env is None:  # img
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    else:
        img = _read_lmdb_img(env, path)
    img = img.astype(np.float32) / 255.
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)
    # some images have 4 channels
    if img.shape[2] > 3:
        img = img[:, :, :3]
    return img


####################
# image processing
# process on numpy image
####################


def augment(img_list, hflip=True, rot=True):
    # horizontal flip OR rotate
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip: img = img[:, ::-1, :]
        if vflip: img = img[::-1, :, :]
        if rot90: img = img.transpose(1, 0, 2)
        return img

    return [_augment(img) for img in img_list]


def channel_convert(in_c, tar_type, img_list):
    # conversion among BGR, gray and y
    if in_c == 3 and tar_type == 'gray':  # BGR to gray
        gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
        return [np.expand_dims(img, axis=2) for img in gray_list]
    elif in_c == 3 and tar_type == 'y':  # BGR to y
        y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
        return [np.expand_dims(img, axis=2) for img in y_list]
    elif in_c == 1 and tar_type == 'RGB':  # gray/y to BGR
        return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
    else:
        return img_list


def rgb2ycbcr(img, only_y=True):
    '''same as matlab rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    '''
    in_img_type = img.dtype
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)


def bgr2ycbcr(img, only_y=True):
    '''bgr version of rgb2ycbcr
    only_y: only return Y channel
    Input:
        uint8, [0, 255]
        float, [0, 1]
    '''
    in_img_type = img.dtype
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)


def ycbcr2rgb(img):
    '''same as matlab ycbcr2rgb
    Input:
        uint8, [0, 255]
        float, [0, 1]
    '''
    in_img_type = img.dtype
    img = img.astype(np.float32)
    if in_img_type != np.uint8:
        img *= 255.
    # convert
    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
    if in_img_type == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(in_img_type)


def modcrop(img_in, scale):
    # img_in: Numpy, HWC or HW
    img = np.copy(img_in)
    if img.ndim == 2:
        H, W = img.shape
        H_r, W_r = H % scale, W % scale
        img = img[:H - H_r, :W - W_r]
    elif img.ndim == 3:
        H, W, C = img.shape
        H_r, W_r = H % scale, W % scale
        img = img[:H - H_r, :W - W_r, :]
    else:
        raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
    return img


####################
# Functions
####################


# matlab 'imresize' function; currently only supports 'bicubic'
def cubic(x):
    absx = torch.abs(x)
    absx2 = absx**2
    absx3 = absx**3
    return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
        (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))


def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    if (scale < 1) and (antialiasing):
        # Use a modified kernel to simultaneously interpolate and antialias - larger kernel width
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5+scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation? Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    P = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
        1, P).expand(out_length, P)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
    # apply cubic kernel
    if (scale < 1) and (antialiasing):
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)
    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, P)

    # If a column in weights is all zero, get rid of it. Only consider the first and last column.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, P - 2)
        weights = weights.narrow(1, 1, P - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, P - 2)
        weights = weights.narrow(1, 0, P - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)


def imresize(img, scale, antialiasing=True):
    # Now the scale should be the same for H and W
    # input: img: CHW RGB [0,1]
    # output: CHW RGB [0,1] w/o round

    in_C, in_H, in_W = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize. The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:, :sym_len_Hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[:, -sym_len_He:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(in_C, out_H, in_W)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        out_1[0, i, :] = img_aug[0, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
        out_1[1, i, :] = img_aug[1, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
        out_1[2, i, :] = img_aug[2, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_Ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, :, -sym_len_We:]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_C, out_H, out_W)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        out_2[0, :, i] = out_1_aug[0, :, idx:idx + kernel_width].mv(weights_W[i])
        out_2[1, :, i] = out_1_aug[1, :, idx:idx + kernel_width].mv(weights_W[i])
        out_2[2, :, i] = out_1_aug[2, :, idx:idx + kernel_width].mv(weights_W[i])

    return out_2


def imresize_np(img, scale, antialiasing=True):
    # Now the scale should be the same for H and W
    # input: img: Numpy, HWC BGR [0,1]
    # output: HWC BGR [0,1] w/o round
    img = torch.from_numpy(img)

    in_H, in_W, in_C = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize. The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:sym_len_Hs, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)

(Diff view truncated here; the remainder of util.py is only visible in the raw diff.)
|
| 375 |
+
|
| 376 |
+
sym_patch = img[-sym_len_He:, :, :]
|
| 377 |
+
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
|
| 378 |
+
sym_patch_inv = sym_patch.index_select(0, inv_idx)
|
| 379 |
+
img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
|
| 380 |
+
|
| 381 |
+
out_1 = torch.FloatTensor(out_H, in_W, in_C)
|
| 382 |
+
kernel_width = weights_H.size(1)
|
| 383 |
+
for i in range(out_H):
|
| 384 |
+
idx = int(indices_H[i][0])
|
| 385 |
+
out_1[i, :, 0] = img_aug[idx:idx + kernel_width, :, 0].transpose(0, 1).mv(weights_H[i])
|
| 386 |
+
out_1[i, :, 1] = img_aug[idx:idx + kernel_width, :, 1].transpose(0, 1).mv(weights_H[i])
|
| 387 |
+
out_1[i, :, 2] = img_aug[idx:idx + kernel_width, :, 2].transpose(0, 1).mv(weights_H[i])
|
| 388 |
+
|
| 389 |
+
# process W dimension
|
| 390 |
+
# symmetric copying
|
| 391 |
+
out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
|
| 392 |
+
out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
|
| 393 |
+
|
| 394 |
+
sym_patch = out_1[:, :sym_len_Ws, :]
|
| 395 |
+
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
| 396 |
+
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
| 397 |
+
out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
|
| 398 |
+
|
| 399 |
+
sym_patch = out_1[:, -sym_len_We:, :]
|
| 400 |
+
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
|
| 401 |
+
sym_patch_inv = sym_patch.index_select(1, inv_idx)
|
| 402 |
+
out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
|
| 403 |
+
|
| 404 |
+
out_2 = torch.FloatTensor(out_H, out_W, in_C)
|
| 405 |
+
kernel_width = weights_W.size(1)
|
| 406 |
+
for i in range(out_W):
|
| 407 |
+
idx = int(indices_W[i][0])
|
| 408 |
+
out_2[:, i, 0] = out_1_aug[:, idx:idx + kernel_width, 0].mv(weights_W[i])
|
| 409 |
+
out_2[:, i, 1] = out_1_aug[:, idx:idx + kernel_width, 1].mv(weights_W[i])
|
| 410 |
+
out_2[:, i, 2] = out_1_aug[:, idx:idx + kernel_width, 2].mv(weights_W[i])
|
| 411 |
+
|
| 412 |
+
return out_2.numpy()
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
if __name__ == '__main__':
|
| 416 |
+
# test imresize function
|
| 417 |
+
# read images
|
| 418 |
+
img = cv2.imread('test.png')
|
| 419 |
+
img = img * 1.0 / 255
|
| 420 |
+
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
|
| 421 |
+
# imresize
|
| 422 |
+
scale = 1 / 4
|
| 423 |
+
import time
|
| 424 |
+
total_time = 0
|
| 425 |
+
for i in range(10):
|
| 426 |
+
start_time = time.time()
|
| 427 |
+
rlt = imresize(img, scale, antialiasing=True)
|
| 428 |
+
use_time = time.time() - start_time
|
| 429 |
+
total_time += use_time
|
| 430 |
+
print('average time: {}'.format(total_time / 10))
|
| 431 |
+
|
| 432 |
+
import torchvision.utils
|
| 433 |
+
torchvision.utils.save_image(
|
| 434 |
+
(rlt * 255).round() / 255, 'rlt.png', nrow=1, padding=0, normalize=False)
|
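For orientation (not part of the commit): a minimal usage sketch of the MATLAB-style bicubic helper above, assuming this module is importable as `data.util` from `esrgan_plus/codes`; `input.png` is a hypothetical test image.

# Minimal sketch: MATLAB-style 4x downscale with imresize_np (assumptions noted above).
import cv2
import numpy as np

from data.util import imresize_np

img = cv2.imread('input.png').astype(np.float64) / 255.  # HWC BGR in [0, 1]
lr = imresize_np(img, 1 / 4, antialiasing=True)          # antialiased bicubic downscale
cv2.imwrite('input_LRx4.png', (lr.clip(0, 1) * 255.).round().astype(np.uint8))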
esrgan_plus/codes/models/SFTGAN_ACD_model.py
ADDED
@@ -0,0 +1,261 @@
+import os
+import logging
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+from torch.optim import lr_scheduler
+
+import models.networks as networks
+from .base_model import BaseModel
+from models.modules.loss import GANLoss, GradientPenaltyLoss
+
+logger = logging.getLogger('base')
+
+
+class SFTGAN_ACD_Model(BaseModel):
+    def __init__(self, opt):
+        super(SFTGAN_ACD_Model, self).__init__(opt)
+        train_opt = opt['train']
+
+        # define networks and load pretrained models
+        self.netG = networks.define_G(opt).to(self.device)  # G
+        if self.is_train:
+            self.netD = networks.define_D(opt).to(self.device)  # D
+            self.netG.train()
+            self.netD.train()
+        self.load()  # load G and D if needed
+
+        # define losses, optimizer and scheduler
+        if self.is_train:
+            # G pixel loss
+            if train_opt['pixel_weight'] > 0:
+                l_pix_type = train_opt['pixel_criterion']
+                if l_pix_type == 'l1':
+                    self.cri_pix = nn.L1Loss().to(self.device)
+                elif l_pix_type == 'l2':
+                    self.cri_pix = nn.MSELoss().to(self.device)
+                else:
+                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
+                self.l_pix_w = train_opt['pixel_weight']
+            else:
+                logger.info('Remove pixel loss.')
+                self.cri_pix = None
+
+            # G feature loss
+            if train_opt['feature_weight'] > 0:
+                l_fea_type = train_opt['feature_criterion']
+                if l_fea_type == 'l1':
+                    self.cri_fea = nn.L1Loss().to(self.device)
+                elif l_fea_type == 'l2':
+                    self.cri_fea = nn.MSELoss().to(self.device)
+                else:
+                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
+                self.l_fea_w = train_opt['feature_weight']
+            else:
+                logger.info('Remove feature loss.')
+                self.cri_fea = None
+            if self.cri_fea:  # load VGG perceptual loss
+                self.netF = networks.define_F(opt, use_bn=False).to(self.device)
+
+            # GD gan loss
+            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
+            self.l_gan_w = train_opt['gan_weight']
+            # D_update_ratio and D_init_iters are for WGAN
+            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
+            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
+
+            if train_opt['gan_type'] == 'wgan-gp':
+                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
+                # gradient penalty loss
+                self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
+                self.l_gp_w = train_opt['gp_weigth']  # (sic: key spelling matches the options files)
+
+            # D cls loss
+            self.cri_ce = nn.CrossEntropyLoss(ignore_index=0).to(self.device)
+            # ignore background, since bg images may conflict with other classes
+
+            # optimizers
+            # G
+            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
+            optim_params_SFT = []
+            optim_params_other = []
+            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
+                if 'SFT' in k or 'Cond' in k:
+                    optim_params_SFT.append(v)
+                else:
+                    optim_params_other.append(v)
+            self.optimizer_G_SFT = torch.optim.Adam(optim_params_SFT, lr=train_opt['lr_G'] * 5, \
+                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
+            self.optimizer_G_other = torch.optim.Adam(optim_params_other, lr=train_opt['lr_G'], \
+                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
+            self.optimizers.append(self.optimizer_G_SFT)
+            self.optimizers.append(self.optimizer_G_other)
+            # D
+            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
+            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
+                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
+            self.optimizers.append(self.optimizer_D)
+
+            # schedulers
+            if train_opt['lr_scheme'] == 'MultiStepLR':
+                for optimizer in self.optimizers:
+                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
+                        train_opt['lr_steps'], train_opt['lr_gamma']))
+            else:
+                raise NotImplementedError('Only MultiStepLR scheme is supported.')
+
+        self.log_dict = OrderedDict()
+        # print network
+        self.print_network()
+
+    def feed_data(self, data, need_HR=True):
+        # LR
+        self.var_L = data['LR'].to(self.device)
+        # seg
+        self.var_seg = data['seg'].to(self.device)
+        # category
+        self.var_cat = data['category'].long().to(self.device)
+
+        if need_HR:  # train or val
+            self.var_H = data['HR'].to(self.device)
+
+    def optimize_parameters(self, step):
+        # G
+        self.optimizer_G_SFT.zero_grad()
+        self.optimizer_G_other.zero_grad()
+        self.fake_H = self.netG((self.var_L, self.var_seg))
+
+        l_g_total = 0
+        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+            if self.cri_pix:  # pixel loss
+                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
+                l_g_total += l_g_pix
+            if self.cri_fea:  # feature loss
+                real_fea = self.netF(self.var_H).detach()
+                fake_fea = self.netF(self.fake_H)
+                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
+                l_g_total += l_g_fea
+            # G gan + cls loss
+            pred_g_fake, cls_g_fake = self.netD(self.fake_H)
+            l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
+            l_g_cls = self.l_gan_w * self.cri_ce(cls_g_fake, self.var_cat)
+            l_g_total += l_g_gan
+            l_g_total += l_g_cls
+
+            l_g_total.backward()
+            self.optimizer_G_SFT.step()
+            if step > 20000:
+                self.optimizer_G_other.step()
+
+        # D
+        self.optimizer_D.zero_grad()
+        l_d_total = 0
+        # real data
+        pred_d_real, cls_d_real = self.netD(self.var_H)
+        l_d_real = self.cri_gan(pred_d_real, True)
+        l_d_cls_real = self.cri_ce(cls_d_real, self.var_cat)
+        # fake data
+        pred_d_fake, cls_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
+        l_d_fake = self.cri_gan(pred_d_fake, False)
+        l_d_cls_fake = self.cri_ce(cls_d_fake, self.var_cat)
+
+        l_d_total = l_d_real + l_d_cls_real + l_d_fake + l_d_cls_fake
+
+        if self.opt['train']['gan_type'] == 'wgan-gp':
+            batch_size = self.var_H.size(0)
+            if self.random_pt.size(0) != batch_size:
+                self.random_pt.resize_(batch_size, 1, 1, 1)
+            self.random_pt.uniform_()  # Draw random interpolation points
+            interp = self.random_pt * self.fake_H.detach() + (1 - self.random_pt) * self.var_H
+            interp.requires_grad = True
+            interp_crit, _ = self.netD(interp)
+            l_d_gp = self.l_gp_w * self.cri_gp(interp, interp_crit)  # maybe wrong in cls?
+            l_d_total += l_d_gp
+
+        l_d_total.backward()
+        self.optimizer_D.step()
+
+        # set log
+        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+            # G
+            if self.cri_pix:
+                self.log_dict['l_g_pix'] = l_g_pix.item()
+            if self.cri_fea:
+                self.log_dict['l_g_fea'] = l_g_fea.item()
+            self.log_dict['l_g_gan'] = l_g_gan.item()
+        # D
+        self.log_dict['l_d_real'] = l_d_real.item()
+        self.log_dict['l_d_fake'] = l_d_fake.item()
+        self.log_dict['l_d_cls_real'] = l_d_cls_real.item()
+        self.log_dict['l_d_cls_fake'] = l_d_cls_fake.item()
+        if self.opt['train']['gan_type'] == 'wgan-gp':
+            self.log_dict['l_d_gp'] = l_d_gp.item()
+        # D outputs
+        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
+        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
+
+    def test(self):
+        self.netG.eval()
+        with torch.no_grad():
+            self.fake_H = self.netG((self.var_L, self.var_seg))
+        self.netG.train()
+
+    def get_current_log(self):
+        return self.log_dict
+
+    def get_current_visuals(self, need_HR=True):
+        out_dict = OrderedDict()
+        out_dict['LR'] = self.var_L.detach()[0].float().cpu()
+        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
+        if need_HR:
+            out_dict['HR'] = self.var_H.detach()[0].float().cpu()
+        return out_dict
+
+    def print_network(self):
+        # G
+        s, n = self.get_network_description(self.netG)
+        if isinstance(self.netG, nn.DataParallel):
+            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
+                                             self.netG.module.__class__.__name__)
+        else:
+            net_struc_str = '{}'.format(self.netG.__class__.__name__)
+
+        logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+        logger.info(s)
+        if self.is_train:
+            # D
+            s, n = self.get_network_description(self.netD)
+            if isinstance(self.netD, nn.DataParallel):
+                net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
+                                                 self.netD.module.__class__.__name__)
+            else:
+                net_struc_str = '{}'.format(self.netD.__class__.__name__)
+
+            logger.info('Network D structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+            logger.info(s)
+
+            if self.cri_fea:  # F, Perceptual Network
+                s, n = self.get_network_description(self.netF)
+                if isinstance(self.netF, nn.DataParallel):
+                    net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
+                                                     self.netF.module.__class__.__name__)
+                else:
+                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
+
+                logger.info('Network F structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+                logger.info(s)
+
+    def load(self):
+        load_path_G = self.opt['path']['pretrain_model_G']
+        if load_path_G is not None:
+            logger.info('Loading pretrained model for G [{:s}] ...'.format(load_path_G))
+            self.load_network(load_path_G, self.netG)
+        load_path_D = self.opt['path']['pretrain_model_D']
+        if self.opt['is_train'] and load_path_D is not None:
+            logger.info('Loading pretrained model for D [{:s}] ...'.format(load_path_D))
+            self.load_network(load_path_D, self.netD)
+
+    def save(self, iter_step):
+        self.save_network(self.netG, 'G', iter_step)
+        self.save_network(self.netD, 'D', iter_step)
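The constructor above splits the generator's parameters into two groups by name so that SFT/condition layers run at five times the base learning rate. The same pattern works on any module; a minimal sketch with a hypothetical toy network:

# Name-based parameter split, as used in SFTGAN_ACD_Model.__init__ (toy module).
import torch
import torch.nn as nn

net = nn.ModuleDict({'SFT_scale': nn.Linear(8, 8), 'body': nn.Linear(8, 8)})
sft_params, other_params = [], []
for name, p in net.named_parameters():
    (sft_params if 'SFT' in name or 'Cond' in name else other_params).append(p)

base_lr = 1e-4
opt_sft = torch.optim.Adam(sft_params, lr=base_lr * 5)  # SFT/Cond layers: 5x lr
opt_other = torch.optim.Adam(other_params, lr=base_lr)  # everything else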
esrgan_plus/codes/models/SRGAN_model.py
ADDED
@@ -0,0 +1,240 @@
+import os
+import logging
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+from torch.optim import lr_scheduler
+
+import models.networks as networks
+from .base_model import BaseModel
+from models.modules.loss import GANLoss, GradientPenaltyLoss
+
+logger = logging.getLogger('base')
+
+
+class SRGANModel(BaseModel):
+    def __init__(self, opt):
+        super(SRGANModel, self).__init__(opt)
+        train_opt = opt['train']
+
+        # define networks and load pretrained models
+        self.netG = networks.define_G(opt).to(self.device)  # G
+        if self.is_train:
+            self.netD = networks.define_D(opt).to(self.device)  # D
+            self.netG.train()
+            self.netD.train()
+        self.load()  # load G and D if needed
+
+        # define losses, optimizer and scheduler
+        if self.is_train:
+            # G pixel loss
+            if train_opt['pixel_weight'] > 0:
+                l_pix_type = train_opt['pixel_criterion']
+                if l_pix_type == 'l1':
+                    self.cri_pix = nn.L1Loss().to(self.device)
+                elif l_pix_type == 'l2':
+                    self.cri_pix = nn.MSELoss().to(self.device)
+                else:
+                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
+                self.l_pix_w = train_opt['pixel_weight']
+            else:
+                logger.info('Remove pixel loss.')
+                self.cri_pix = None
+
+            # G feature loss
+            if train_opt['feature_weight'] > 0:
+                l_fea_type = train_opt['feature_criterion']
+                if l_fea_type == 'l1':
+                    self.cri_fea = nn.L1Loss().to(self.device)
+                elif l_fea_type == 'l2':
+                    self.cri_fea = nn.MSELoss().to(self.device)
+                else:
+                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
+                self.l_fea_w = train_opt['feature_weight']
+            else:
+                logger.info('Remove feature loss.')
+                self.cri_fea = None
+            if self.cri_fea:  # load VGG perceptual loss
+                self.netF = networks.define_F(opt, use_bn=False).to(self.device)
+
+            # GD gan loss
+            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
+            self.l_gan_w = train_opt['gan_weight']
+            # D_update_ratio and D_init_iters are for WGAN
+            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
+            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
+
+            if train_opt['gan_type'] == 'wgan-gp':
+                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
+                # gradient penalty loss
+                self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
+                self.l_gp_w = train_opt['gp_weigth']  # (sic: key spelling matches the options files)
+
+            # optimizers
+            # G
+            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
+            optim_params = []
+            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
+                if v.requires_grad:
+                    optim_params.append(v)
+                else:
+                    logger.warning('Params [{:s}] will not optimize.'.format(k))
+            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
+                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
+            self.optimizers.append(self.optimizer_G)
+            # D
+            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
+            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
+                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
+            self.optimizers.append(self.optimizer_D)
+
+            # schedulers
+            if train_opt['lr_scheme'] == 'MultiStepLR':
+                for optimizer in self.optimizers:
+                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
+                        train_opt['lr_steps'], train_opt['lr_gamma']))
+            else:
+                raise NotImplementedError('Only MultiStepLR scheme is supported.')
+
+        self.log_dict = OrderedDict()
+        # print network
+        self.print_network()
+
+    def feed_data(self, data, need_HR=True):
+        # LR
+        self.var_L = data['LR'].to(self.device)
+        if need_HR:  # train or val
+            self.var_H = data['HR'].to(self.device)
+
+            input_ref = data['ref'] if 'ref' in data else data['HR']
+            self.var_ref = input_ref.to(self.device)
+
+    def optimize_parameters(self, step):
+        # G
+        self.optimizer_G.zero_grad()
+        self.fake_H = self.netG(self.var_L)
+
+        l_g_total = 0
+        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+            if self.cri_pix:  # pixel loss
+                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
+                l_g_total += l_g_pix
+            if self.cri_fea:  # feature loss
+                real_fea = self.netF(self.var_H).detach()
+                fake_fea = self.netF(self.fake_H)
+                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
+                l_g_total += l_g_fea
+            # G gan loss
+            pred_g_fake = self.netD(self.fake_H)
+            l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
+            l_g_total += l_g_gan
+
+            l_g_total.backward()
+            self.optimizer_G.step()
+
+        # D
+        self.optimizer_D.zero_grad()
+        l_d_total = 0
+        # real data
+        pred_d_real = self.netD(self.var_ref)
+        l_d_real = self.cri_gan(pred_d_real, True)
+        # fake data
+        pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
+        l_d_fake = self.cri_gan(pred_d_fake, False)
+
+        l_d_total = l_d_real + l_d_fake
+
+        if self.opt['train']['gan_type'] == 'wgan-gp':
+            batch_size = self.var_ref.size(0)
+            if self.random_pt.size(0) != batch_size:
+                self.random_pt.resize_(batch_size, 1, 1, 1)
+            self.random_pt.uniform_()  # Draw random interpolation points
+            interp = self.random_pt * self.fake_H.detach() + (1 - self.random_pt) * self.var_ref
+            interp.requires_grad = True
+            interp_crit = self.netD(interp)
+            l_d_gp = self.l_gp_w * self.cri_gp(interp, interp_crit)
+            l_d_total += l_d_gp
+
+        l_d_total.backward()
+        self.optimizer_D.step()
+
+        # set log
+        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+            # G
+            if self.cri_pix:
+                self.log_dict['l_g_pix'] = l_g_pix.item()
+            if self.cri_fea:
+                self.log_dict['l_g_fea'] = l_g_fea.item()
+            self.log_dict['l_g_gan'] = l_g_gan.item()
+        # D
+        self.log_dict['l_d_real'] = l_d_real.item()
+        self.log_dict['l_d_fake'] = l_d_fake.item()
+
+        if self.opt['train']['gan_type'] == 'wgan-gp':
+            self.log_dict['l_d_gp'] = l_d_gp.item()
+        # D outputs
+        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
+        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
+
+    def test(self):
+        self.netG.eval()
+        with torch.no_grad():
+            self.fake_H = self.netG(self.var_L)
+        self.netG.train()
+
+    def get_current_log(self):
+        return self.log_dict
+
+    def get_current_visuals(self, need_HR=True):
+        out_dict = OrderedDict()
+        out_dict['LR'] = self.var_L.detach()[0].float().cpu()
+        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
+        if need_HR:
+            out_dict['HR'] = self.var_H.detach()[0].float().cpu()
+        return out_dict
+
+    def print_network(self):
+        # Generator
+        s, n = self.get_network_description(self.netG)
+        if isinstance(self.netG, nn.DataParallel):
+            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
+                                             self.netG.module.__class__.__name__)
+        else:
+            net_struc_str = '{}'.format(self.netG.__class__.__name__)
+        logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+        logger.info(s)
+        if self.is_train:
+            # Discriminator
+            s, n = self.get_network_description(self.netD)
+            if isinstance(self.netD, nn.DataParallel):
+                net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
+                                                 self.netD.module.__class__.__name__)
+            else:
+                net_struc_str = '{}'.format(self.netD.__class__.__name__)
+            logger.info('Network D structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+            logger.info(s)
+
+            if self.cri_fea:  # F, Perceptual Network
+                s, n = self.get_network_description(self.netF)
+                if isinstance(self.netF, nn.DataParallel):
+                    net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
+                                                     self.netF.module.__class__.__name__)
+                else:
+                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
+                logger.info('Network F structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+                logger.info(s)
+
+    def load(self):
+        load_path_G = self.opt['path']['pretrain_model_G']
+        if load_path_G is not None:
+            logger.info('Loading pretrained model for G [{:s}] ...'.format(load_path_G))
+            self.load_network(load_path_G, self.netG)
+        load_path_D = self.opt['path']['pretrain_model_D']
+        if self.opt['is_train'] and load_path_D is not None:
+            logger.info('Loading pretrained model for D [{:s}] ...'.format(load_path_D))
+            self.load_network(load_path_D, self.netD)
+
+    def save(self, iter_step):
+        self.save_network(self.netG, 'G', iter_step)
+        self.save_network(self.netD, 'D', iter_step)
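The `wgan-gp` branch above draws one random point per sample on the line between the real and fake batches and penalizes the critic's gradient norm there. The repo wraps this in `GradientPenaltyLoss` (models/modules/loss.py, not shown in this view); a standalone sketch of the same penalty, written out with autograd directly (`critic`, `real`, `fake` are hypothetical):

# Standalone sketch of the gradient-penalty term behind the wgan-gp branch.
import torch

def gradient_penalty(critic, real, fake):
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interp = (alpha * fake.detach() + (1 - alpha) * real).requires_grad_(True)
    crit = critic(interp)
    grad = torch.autograd.grad(outputs=crit.sum(), inputs=interp,
                               create_graph=True)[0]
    # penalize deviation of the per-sample gradient norm from 1
    return ((grad.view(grad.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()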
esrgan_plus/codes/models/SRRaGAN_model.py
ADDED
@@ -0,0 +1,251 @@
+import os
+import logging
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+from torch.optim import lr_scheduler
+
+import models.networks as networks
+from .base_model import BaseModel
+from models.modules.loss import GANLoss, GradientPenaltyLoss
+logger = logging.getLogger('base')
+
+
+class SRRaGANModel(BaseModel):
+    def __init__(self, opt):
+        super(SRRaGANModel, self).__init__(opt)
+        train_opt = opt['train']
+
+        # define networks and load pretrained models
+        self.netG = networks.define_G(opt).to(self.device)  # G
+        if self.is_train:
+            self.netD = networks.define_D(opt).to(self.device)  # D
+            self.netG.train()
+            self.netD.train()
+        self.load()  # load G and D if needed
+
+        # define losses, optimizer and scheduler
+        if self.is_train:
+            # G pixel loss
+            if train_opt['pixel_weight'] > 0:
+                l_pix_type = train_opt['pixel_criterion']
+                if l_pix_type == 'l1':
+                    self.cri_pix = nn.L1Loss().to(self.device)
+                elif l_pix_type == 'l2':
+                    self.cri_pix = nn.MSELoss().to(self.device)
+                else:
+                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
+                self.l_pix_w = train_opt['pixel_weight']
+            else:
+                logger.info('Remove pixel loss.')
+                self.cri_pix = None
+
+            # G feature loss
+            if train_opt['feature_weight'] > 0:
+                l_fea_type = train_opt['feature_criterion']
+                if l_fea_type == 'l1':
+                    self.cri_fea = nn.L1Loss().to(self.device)
+                elif l_fea_type == 'l2':
+                    self.cri_fea = nn.MSELoss().to(self.device)
+                else:
+                    raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
+                self.l_fea_w = train_opt['feature_weight']
+            else:
+                logger.info('Remove feature loss.')
+                self.cri_fea = None
+            if self.cri_fea:  # load VGG perceptual loss
+                self.netF = networks.define_F(opt, use_bn=False).to(self.device)
+
+            # GD gan loss
+            self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
+            self.l_gan_w = train_opt['gan_weight']
+            # D_update_ratio and D_init_iters are for WGAN
+            self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
+            self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
+
+            if train_opt['gan_type'] == 'wgan-gp':
+                self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device)
+                # gradient penalty loss
+                self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device)
+                self.l_gp_w = train_opt['gp_weigth']  # (sic: key spelling matches the options files)
+
+            # optimizers
+            # G
+            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
+            optim_params = []
+            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
+                if v.requires_grad:
+                    optim_params.append(v)
+                else:
+                    logger.warning('Params [{:s}] will not optimize.'.format(k))
+            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], \
+                weight_decay=wd_G, betas=(train_opt['beta1_G'], 0.999))
+            self.optimizers.append(self.optimizer_G)
+            # D
+            wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
+            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], \
+                weight_decay=wd_D, betas=(train_opt['beta1_D'], 0.999))
+            self.optimizers.append(self.optimizer_D)
+
+            # schedulers
+            if train_opt['lr_scheme'] == 'MultiStepLR':
+                for optimizer in self.optimizers:
+                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
+                        train_opt['lr_steps'], train_opt['lr_gamma']))
+            else:
+                raise NotImplementedError('Only MultiStepLR scheme is supported.')
+
+        self.log_dict = OrderedDict()
+        # print network
+        self.print_network()
+
+    def feed_data(self, data, need_HR=True):
+        # LR
+        self.var_L = data['LR'].to(self.device)
+
+        if need_HR:  # train or val
+            self.var_H = data['HR'].to(self.device)
+
+            input_ref = data['ref'] if 'ref' in data else data['HR']
+            self.var_ref = input_ref.to(self.device)
+
+    def optimize_parameters(self, step):
+        # G
+        for p in self.netD.parameters():
+            p.requires_grad = False
+
+        self.optimizer_G.zero_grad()
+
+        self.fake_H = self.netG(self.var_L)
+
+        l_g_total = 0
+        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+            if self.cri_pix:  # pixel loss
+                l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
+                l_g_total += l_g_pix
+            if self.cri_fea:  # feature loss
+                real_fea = self.netF(self.var_H).detach()
+                fake_fea = self.netF(self.fake_H)
+                l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
+                l_g_total += l_g_fea
+            # G gan loss (relativistic average GAN)
+            pred_g_fake = self.netD(self.fake_H)
+            pred_d_real = self.netD(self.var_ref).detach()
+
+            l_g_gan = self.l_gan_w * (self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
+                                      self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
+            l_g_total += l_g_gan
+
+            l_g_total.backward()
+            self.optimizer_G.step()
+
+        # D
+        for p in self.netD.parameters():
+            p.requires_grad = True
+
+        self.optimizer_D.zero_grad()
+        l_d_total = 0
+        pred_d_real = self.netD(self.var_ref)
+        pred_d_fake = self.netD(self.fake_H.detach())  # detach to avoid BP to G
+        l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
+        l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
+
+        l_d_total = (l_d_real + l_d_fake) / 2
+
+        if self.opt['train']['gan_type'] == 'wgan-gp':
+            batch_size = self.var_ref.size(0)
+            if self.random_pt.size(0) != batch_size:
+                self.random_pt.resize_(batch_size, 1, 1, 1)
+            self.random_pt.uniform_()  # Draw random interpolation points
+            interp = self.random_pt * self.fake_H.detach() + (1 - self.random_pt) * self.var_ref
+            interp.requires_grad = True
+            interp_crit = self.netD(interp)
+            l_d_gp = self.l_gp_w * self.cri_gp(interp, interp_crit)
+            l_d_total += l_d_gp
+
+        l_d_total.backward()
+        self.optimizer_D.step()
+
+        # set log
+        if step % self.D_update_ratio == 0 and step > self.D_init_iters:
+            # G
+            if self.cri_pix:
+                self.log_dict['l_g_pix'] = l_g_pix.item()
+            if self.cri_fea:
+                self.log_dict['l_g_fea'] = l_g_fea.item()
+            self.log_dict['l_g_gan'] = l_g_gan.item()
+        # D
+        self.log_dict['l_d_real'] = l_d_real.item()
+        self.log_dict['l_d_fake'] = l_d_fake.item()
+
+        if self.opt['train']['gan_type'] == 'wgan-gp':
+            self.log_dict['l_d_gp'] = l_d_gp.item()
+        # D outputs
+        self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
+        self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
+
+    def test(self):
+        self.netG.eval()
+        with torch.no_grad():
+            self.fake_H = self.netG(self.var_L)
+        self.netG.train()
+
+    def get_current_log(self):
+        return self.log_dict
+
+    def get_current_visuals(self, need_HR=True):
+        out_dict = OrderedDict()
+        out_dict['LR'] = self.var_L.detach()[0].float().cpu()
+        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
+        if need_HR:
+            out_dict['HR'] = self.var_H.detach()[0].float().cpu()
+        return out_dict
+
+    def print_network(self):
+        # Generator
+        s, n = self.get_network_description(self.netG)
+        if isinstance(self.netG, nn.DataParallel):
+            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
+                                             self.netG.module.__class__.__name__)
+        else:
+            net_struc_str = '{}'.format(self.netG.__class__.__name__)
+
+        logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+        logger.info(s)
+        if self.is_train:
+            # Discriminator
+            s, n = self.get_network_description(self.netD)
+            if isinstance(self.netD, nn.DataParallel):
+                net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
+                                                 self.netD.module.__class__.__name__)
+            else:
+                net_struc_str = '{}'.format(self.netD.__class__.__name__)
+
+            logger.info('Network D structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+            logger.info(s)
+
+            if self.cri_fea:  # F, Perceptual Network
+                s, n = self.get_network_description(self.netF)
+                if isinstance(self.netF, nn.DataParallel):
+                    net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
+                                                     self.netF.module.__class__.__name__)
+                else:
+                    net_struc_str = '{}'.format(self.netF.__class__.__name__)
+
+                logger.info('Network F structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+                logger.info(s)
+
+    def load(self):
+        load_path_G = self.opt['path']['pretrain_model_G']
+        if load_path_G is not None:
+            logger.info('Loading pretrained model for G [{:s}] ...'.format(load_path_G))
+            self.load_network(load_path_G, self.netG)
+        load_path_D = self.opt['path']['pretrain_model_D']
+        if self.opt['is_train'] and load_path_D is not None:
+            logger.info('Loading pretrained model for D [{:s}] ...'.format(load_path_D))
+            self.load_network(load_path_D, self.netD)
+
+    def save(self, iter_step):
+        self.save_network(self.netG, 'G', iter_step)
+        self.save_network(self.netD, 'D', iter_step)
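The only substantive change relative to SRGANModel is the relativistic average GAN loss: each logit is compared against the mean logit of the opposite class before the BCE. A worked sketch on raw logits, assuming GANLoss is in 'vanilla' (BCE-with-logits) mode:

# Worked sketch of the RaGAN losses computed in optimize_parameters above.
import torch
import torch.nn.functional as F

pred_real = torch.randn(4, 1)  # D(x_r): logits for a real batch
pred_fake = torch.randn(4, 1)  # D(x_f): logits for a fake batch

ones, zeros = torch.ones_like(pred_real), torch.zeros_like(pred_real)
# D: real should score above the average fake, fake below the average real
l_d = (F.binary_cross_entropy_with_logits(pred_real - pred_fake.mean(), ones) +
       F.binary_cross_entropy_with_logits(pred_fake - pred_real.mean(), zeros)) / 2
# G: the mirror image, with real logits detached (as in the code above)
l_g = (F.binary_cross_entropy_with_logits(pred_real.detach() - pred_fake.mean(), zeros) +
       F.binary_cross_entropy_with_logits(pred_fake - pred_real.detach().mean(), ones)) / 2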
esrgan_plus/codes/models/SR_model.py
ADDED
@@ -0,0 +1,151 @@
+import os
+import logging
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+from torch.optim import lr_scheduler
+
+import models.networks as networks
+from .base_model import BaseModel
+
+logger = logging.getLogger('base')
+
+
+class SRModel(BaseModel):
+    def __init__(self, opt):
+        super(SRModel, self).__init__(opt)
+        train_opt = opt['train']
+
+        # define network and load pretrained models
+        self.netG = networks.define_G(opt).to(self.device)
+        self.load()
+
+        if self.is_train:
+            self.netG.train()
+
+            # loss
+            loss_type = train_opt['pixel_criterion']
+            if loss_type == 'l1':
+                self.cri_pix = nn.L1Loss().to(self.device)
+            elif loss_type == 'l2':
+                self.cri_pix = nn.MSELoss().to(self.device)
+            else:
+                raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
+            self.l_pix_w = train_opt['pixel_weight']
+
+            # optimizers
+            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
+            optim_params = []
+            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
+                if v.requires_grad:
+                    optim_params.append(v)
+                else:
+                    logger.warning('Params [{:s}] will not optimize.'.format(k))
+            self.optimizer_G = torch.optim.Adam(
+                optim_params, lr=train_opt['lr_G'], weight_decay=wd_G)
+            self.optimizers.append(self.optimizer_G)
+
+            # schedulers
+            if train_opt['lr_scheme'] == 'MultiStepLR':
+                for optimizer in self.optimizers:
+                    self.schedulers.append(lr_scheduler.MultiStepLR(optimizer, \
+                        train_opt['lr_steps'], train_opt['lr_gamma']))
+            else:
+                raise NotImplementedError('Only MultiStepLR scheme is supported.')
+
+        self.log_dict = OrderedDict()
+        # print network
+        self.print_network()
+
+    def feed_data(self, data, need_HR=True):
+        self.var_L = data['LR'].to(self.device)  # LR
+        if need_HR:
+            self.real_H = data['HR'].to(self.device)  # HR
+
+    def optimize_parameters(self, step):
+        self.optimizer_G.zero_grad()
+        self.fake_H = self.netG(self.var_L)
+        l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
+        l_pix.backward()
+        self.optimizer_G.step()
+
+        # set log
+        self.log_dict['l_pix'] = l_pix.item()
+
+    def test(self):
+        self.netG.eval()
+        with torch.no_grad():
+            self.fake_H = self.netG(self.var_L)
+        self.netG.train()
+
+    def test_x8(self):
+        # geometric self-ensemble: average SR results over 8 flip/transpose variants
+        # from https://github.com/thstkdgus35/EDSR-PyTorch
+        self.netG.eval()
+        for k, v in self.netG.named_parameters():
+            v.requires_grad = False
+
+        def _transform(v, op):
+            # if self.precision != 'single': v = v.float()
+            v2np = v.data.cpu().numpy()
+            if op == 'v':
+                tfnp = v2np[:, :, :, ::-1].copy()
+            elif op == 'h':
+                tfnp = v2np[:, :, ::-1, :].copy()
+            elif op == 't':
+                tfnp = v2np.transpose((0, 1, 3, 2)).copy()
+
+            ret = torch.Tensor(tfnp).to(self.device)
+            # if self.precision == 'half': ret = ret.half()
+
+            return ret
+
+        lr_list = [self.var_L]
+        for tf in 'v', 'h', 't':
+            lr_list.extend([_transform(t, tf) for t in lr_list])
+        sr_list = [self.netG(aug) for aug in lr_list]
+        for i in range(len(sr_list)):
+            if i > 3:
+                sr_list[i] = _transform(sr_list[i], 't')
+            if i % 4 > 1:
+                sr_list[i] = _transform(sr_list[i], 'h')
+            if (i % 4) % 2 == 1:
+                sr_list[i] = _transform(sr_list[i], 'v')
+
+        output_cat = torch.cat(sr_list, dim=0)
+        self.fake_H = output_cat.mean(dim=0, keepdim=True)
+
+        for k, v in self.netG.named_parameters():
+            v.requires_grad = True
+        self.netG.train()
+
+    def get_current_log(self):
+        return self.log_dict
+
+    def get_current_visuals(self, need_HR=True):
+        out_dict = OrderedDict()
+        out_dict['LR'] = self.var_L.detach()[0].float().cpu()
+        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
+        if need_HR:
+            out_dict['HR'] = self.real_H.detach()[0].float().cpu()
+        return out_dict
+
+    def print_network(self):
+        s, n = self.get_network_description(self.netG)
+        if isinstance(self.netG, nn.DataParallel):
+            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
+                                             self.netG.module.__class__.__name__)
+        else:
+            net_struc_str = '{}'.format(self.netG.__class__.__name__)
+
+        logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
+        logger.info(s)
+
+    def load(self):
+        load_path_G = self.opt['path']['pretrain_model_G']
+        if load_path_G is not None:
+            logger.info('Loading pretrained model for G [{:s}] ...'.format(load_path_G))
+            self.load_network(load_path_G, self.netG)
+
+    def save(self, iter_step):
+        self.save_network(self.netG, 'G', iter_step)
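`test_x8` enumerates the 8 geometric variants of the input (identity, two flips, transpose, and their compositions), super-resolves each, maps each result back, and averages. An equivalent compact sketch written with tensor flips/transposes, where `model` and `lr` are hypothetical stand-ins:

# Compact sketch of the 8-fold geometric self-ensemble behind test_x8.
import torch

def forward_x8(model, lr):  # lr: NCHW tensor
    outs = []
    for t in range(2):               # transpose H/W or not
        for hf in range(2):          # horizontal flip or not
            for vf in range(2):      # vertical flip or not
                x = lr.transpose(2, 3) if t else lr
                if hf: x = x.flip(3)
                if vf: x = x.flip(2)
                y = model(x)
                if vf: y = y.flip(2)  # undo in reverse order
                if hf: y = y.flip(3)
                if t: y = y.transpose(2, 3)
                outs.append(y)
    return torch.stack(outs, 0).mean(0)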
esrgan_plus/codes/models/__init__.py
ADDED
@@ -0,0 +1,20 @@
+import logging
+logger = logging.getLogger('base')
+
+
+def create_model(opt):
+    model = opt['model']
+
+    if model == 'sr':
+        from .SR_model import SRModel as M
+    elif model == 'srgan':
+        from .SRGAN_model import SRGANModel as M
+    elif model == 'srragan':
+        from .SRRaGAN_model import SRRaGANModel as M
+    elif model == 'sftgan':
+        from .SFTGAN_ACD_model import SFTGAN_ACD_Model as M
+    else:
+        raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
+    m = M(opt)
+    logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
+    return m
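The factory keeps each model class behind a local import, so a class (and anything heavy it pulls in) is only loaded once the `opt['model']` string selects it; in the repo, `opt` comes from options/options.py parsing one of the JSON files. A self-contained toy illustration of the same lazy-import dispatch, with hypothetical names:

# Toy illustration of the lazy-import dispatch pattern used by create_model.
def create_thing(opt):
    kind = opt['thing']
    if kind == 'list':
        from collections import UserList as T  # imported only when selected
    elif kind == 'dict':
        from collections import UserDict as T
    else:
        raise NotImplementedError('Thing [{:s}] not recognized.'.format(kind))
    return T()

print(type(create_thing({'thing': 'list'})).__name__)  # UserList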
esrgan_plus/codes/models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (807 Bytes)
esrgan_plus/codes/models/base_model.py
ADDED
@@ -0,0 +1,85 @@
+import os
+import torch
+import torch.nn as nn
+
+
+class BaseModel():
+    def __init__(self, opt):
+        self.opt = opt
+        self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
+        self.is_train = opt['is_train']
+        self.schedulers = []
+        self.optimizers = []
+
+    def feed_data(self, data):
+        pass
+
+    def optimize_parameters(self):
+        pass
+
+    def get_current_visuals(self):
+        pass
+
+    def get_current_losses(self):
+        pass
+
+    def print_network(self):
+        pass
+
+    def save(self, label):
+        pass
+
+    def load(self):
+        pass
+
+    def update_learning_rate(self):
+        for scheduler in self.schedulers:
+            scheduler.step()
+
+    def get_current_learning_rate(self):
+        return self.schedulers[0].get_lr()[0]
+
+    def get_network_description(self, network):
+        '''Get the string and total parameters of the network'''
+        if isinstance(network, nn.DataParallel):
+            network = network.module
+        s = str(network)
+        n = sum(map(lambda x: x.numel(), network.parameters()))
+        return s, n
+
+    def save_network(self, network, network_label, iter_step):
+        save_filename = '{}_{}.pth'.format(iter_step, network_label)
+        save_path = os.path.join(self.opt['path']['models'], save_filename)
+        if isinstance(network, nn.DataParallel):
+            network = network.module
+        state_dict = network.state_dict()
+        for key, param in state_dict.items():
+            state_dict[key] = param.cpu()
+        torch.save(state_dict, save_path)
+
+    def load_network(self, load_path, network, strict=True):
+        if isinstance(network, nn.DataParallel):
+            network = network.module
+        network.load_state_dict(torch.load(load_path), strict=strict)
+
+    def save_training_state(self, epoch, iter_step):
+        '''Saves training state during training, which will be used for resuming'''
+        state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []}
+        for s in self.schedulers:
+            state['schedulers'].append(s.state_dict())
+        for o in self.optimizers:
+            state['optimizers'].append(o.state_dict())
+        save_filename = '{}.state'.format(iter_step)
+        save_path = os.path.join(self.opt['path']['training_state'], save_filename)
+        torch.save(state, save_path)
+
+    def resume_training(self, resume_state):
+        '''Resume the optimizers and schedulers for training'''
+        resume_optimizers = resume_state['optimizers']
+        resume_schedulers = resume_state['schedulers']
+        assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
+        assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
+        for i, o in enumerate(resume_optimizers):
+            self.optimizers[i].load_state_dict(o)
+        for i, s in enumerate(resume_schedulers):
+            self.schedulers[i].load_state_dict(s)
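A round-trip sketch of the training-state logic above, reduced to plain torch objects (no BaseModel) to show what ends up inside a `.state` file; the toy network and milestone values are hypothetical:

# Save/resume round trip mirroring save_training_state and resume_training.
import torch
from torch.optim import lr_scheduler

net = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[50000], gamma=0.5)

state = {'epoch': 3, 'iter': 12000,
         'optimizers': [optimizer.state_dict()],
         'schedulers': [scheduler.state_dict()]}
torch.save(state, '12000.state')

resume = torch.load('12000.state')
optimizer.load_state_dict(resume['optimizers'][0])  # what resume_training does
scheduler.load_state_dict(resume['schedulers'][0])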
esrgan_plus/codes/models/modules/__pycache__/architecture.cpython-310.pyc
ADDED
Binary file (11.1 kB)
esrgan_plus/codes/models/modules/__pycache__/block.cpython-310.pyc
ADDED
Binary file (10.6 kB)
esrgan_plus/codes/models/modules/__pycache__/spectral_norm.cpython-310.pyc
ADDED
Binary file (5.46 kB)
esrgan_plus/codes/models/modules/architecture.py
ADDED
@@ -0,0 +1,394 @@
import math
import torch
import torch.nn as nn
import torchvision
from . import block as B
from . import spectral_norm as SN

####################
# Generator
####################


class SRResNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, upscale=4, norm_type='batch', act_type='relu', \
            mode='NAC', res_scale=1, upsample_mode='upconv'):
        super(SRResNet, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
        resnet_blocks = [B.ResNetBlock(nf, nf, nf, norm_type=norm_type, act_type=act_type, \
            mode=mode, res_scale=res_scale) for _ in range(nb)]
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_block
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*resnet_blocks, LR_conv)), \
            *upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        x = self.model(x)
        return x


class RRDBNet(nn.Module):
    def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, \
            act_type='leakyrelu', mode='CNA', upsample_mode='upconv'):
        super(RRDBNet, self).__init__()
        n_upscale = int(math.log(upscale, 2))
        if upscale == 3:
            n_upscale = 1

        fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
        rb_blocks = [B.RRDB(nf, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
            norm_type=norm_type, act_type=act_type, mode='CNA') for _ in range(nb)]
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)

        if upsample_mode == 'upconv':
            upsample_block = B.upconv_block
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
        if upscale == 3:
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)), \
            *upsampler, HR_conv0, HR_conv1)

    def forward(self, x):
        x = self.model(x)
        return x


####################
# Discriminator
####################


# VGG style Discriminator with input size 128*128
class Discriminator_VGG_128(nn.Module):
    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_128, self).__init__()
        # features
        # hxw, c
        # 128, 64
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
            mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 64, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 32, 128
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 16, 256
        conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 8, 512
        conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 4, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8, \
            conv9)

        # classifier
        self.classifier = nn.Sequential(
            nn.Linear(512 * 4 * 4, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


# VGG style Discriminator with input size 128*128, Spectral Normalization
class Discriminator_VGG_128_SN(nn.Module):
    def __init__(self):
        super(Discriminator_VGG_128_SN, self).__init__()
        # features
        # hxw, c
        # 128, 64
        self.lrelu = nn.LeakyReLU(0.2, True)

        self.conv0 = SN.spectral_norm(nn.Conv2d(3, 64, 3, 1, 1))
        self.conv1 = SN.spectral_norm(nn.Conv2d(64, 64, 4, 2, 1))
        # 64, 64
        self.conv2 = SN.spectral_norm(nn.Conv2d(64, 128, 3, 1, 1))
        self.conv3 = SN.spectral_norm(nn.Conv2d(128, 128, 4, 2, 1))
        # 32, 128
        self.conv4 = SN.spectral_norm(nn.Conv2d(128, 256, 3, 1, 1))
        self.conv5 = SN.spectral_norm(nn.Conv2d(256, 256, 4, 2, 1))
        # 16, 256
        self.conv6 = SN.spectral_norm(nn.Conv2d(256, 512, 3, 1, 1))
        self.conv7 = SN.spectral_norm(nn.Conv2d(512, 512, 4, 2, 1))
        # 8, 512
        self.conv8 = SN.spectral_norm(nn.Conv2d(512, 512, 3, 1, 1))
        self.conv9 = SN.spectral_norm(nn.Conv2d(512, 512, 4, 2, 1))
        # 4, 512

        # classifier
        self.linear0 = SN.spectral_norm(nn.Linear(512 * 4 * 4, 100))
        self.linear1 = SN.spectral_norm(nn.Linear(100, 1))

    def forward(self, x):
        x = self.lrelu(self.conv0(x))
        x = self.lrelu(self.conv1(x))
        x = self.lrelu(self.conv2(x))
        x = self.lrelu(self.conv3(x))
        x = self.lrelu(self.conv4(x))
        x = self.lrelu(self.conv5(x))
        x = self.lrelu(self.conv6(x))
        x = self.lrelu(self.conv7(x))
        x = self.lrelu(self.conv8(x))
        x = self.lrelu(self.conv9(x))
        x = x.view(x.size(0), -1)
        x = self.lrelu(self.linear0(x))
        x = self.linear1(x)
        return x


class Discriminator_VGG_96(nn.Module):
    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_96, self).__init__()
        # features
        # hxw, c
        # 96, 64
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
            mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 48, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 24, 128
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 12, 256
        conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 6, 512
        conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 3, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8, \
            conv9)

        # classifier
        self.classifier = nn.Sequential(
            nn.Linear(512 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


class Discriminator_VGG_192(nn.Module):
    def __init__(self, in_nc, base_nf, norm_type='batch', act_type='leakyrelu', mode='CNA'):
        super(Discriminator_VGG_192, self).__init__()
        # features
        # hxw, c
        # 192, 64
        conv0 = B.conv_block(in_nc, base_nf, kernel_size=3, norm_type=None, act_type=act_type, \
            mode=mode)
        conv1 = B.conv_block(base_nf, base_nf, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 96, 64
        conv2 = B.conv_block(base_nf, base_nf*2, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv3 = B.conv_block(base_nf*2, base_nf*2, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 48, 128
        conv4 = B.conv_block(base_nf*2, base_nf*4, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv5 = B.conv_block(base_nf*4, base_nf*4, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 24, 256
        conv6 = B.conv_block(base_nf*4, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv7 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 12, 512
        conv8 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv9 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 6, 512
        conv10 = B.conv_block(base_nf*8, base_nf*8, kernel_size=3, stride=1, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        conv11 = B.conv_block(base_nf*8, base_nf*8, kernel_size=4, stride=2, norm_type=norm_type, \
            act_type=act_type, mode=mode)
        # 3, 512
        self.features = B.sequential(conv0, conv1, conv2, conv3, conv4, conv5, conv6, conv7, conv8, \
            conv9, conv10, conv11)

        # classifier
        self.classifier = nn.Sequential(
            nn.Linear(512 * 3 * 3, 100), nn.LeakyReLU(0.2, True), nn.Linear(100, 1))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


####################
# Perceptual Network
####################


# Assume input range is [0, 1]
class VGGFeatureExtractor(nn.Module):
    def __init__(self,
                 feature_layer=34,
                 use_bn=False,
                 use_input_norm=True,
                 device=torch.device('cpu')):
        super(VGGFeatureExtractor, self).__init__()
        if use_bn:
            model = torchvision.models.vgg19_bn(pretrained=True)
        else:
            model = torchvision.models.vgg19(pretrained=True)
        self.use_input_norm = use_input_norm
        if self.use_input_norm:
            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
            # [0.485-1, 0.456-1, 0.406-1] if input in range [-1,1]
            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            # [0.229*2, 0.224*2, 0.225*2] if input in range [-1,1]
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)
        self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
        # No need to BP to variable
        for k, v in self.features.named_parameters():
            v.requires_grad = False

    def forward(self, x):
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        output = self.features(x)
        return output


# Assume input range is [0, 1]
class ResNet101FeatureExtractor(nn.Module):
    def __init__(self, use_input_norm=True, device=torch.device('cpu')):
        super(ResNet101FeatureExtractor, self).__init__()
        model = torchvision.models.resnet101(pretrained=True)
        self.use_input_norm = use_input_norm
        if self.use_input_norm:
            mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
            # [0.485-1, 0.456-1, 0.406-1] if input in range [-1,1]
            std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
            # [0.229*2, 0.224*2, 0.225*2] if input in range [-1,1]
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)
        self.features = nn.Sequential(*list(model.children())[:8])
        # No need to BP to variable
        for k, v in self.features.named_parameters():
            v.requires_grad = False

    def forward(self, x):
        if self.use_input_norm:
            x = (x - self.mean) / self.std
        output = self.features(x)
        return output


class MINCNet(nn.Module):
    def __init__(self):
        super(MINCNet, self).__init__()
        self.ReLU = nn.ReLU(True)
        self.conv11 = nn.Conv2d(3, 64, 3, 1, 1)
        self.conv12 = nn.Conv2d(64, 64, 3, 1, 1)
        self.maxpool1 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv21 = nn.Conv2d(64, 128, 3, 1, 1)
        self.conv22 = nn.Conv2d(128, 128, 3, 1, 1)
        self.maxpool2 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv31 = nn.Conv2d(128, 256, 3, 1, 1)
        self.conv32 = nn.Conv2d(256, 256, 3, 1, 1)
        self.conv33 = nn.Conv2d(256, 256, 3, 1, 1)
        self.maxpool3 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv41 = nn.Conv2d(256, 512, 3, 1, 1)
        self.conv42 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv43 = nn.Conv2d(512, 512, 3, 1, 1)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, padding=0, ceil_mode=True)
        self.conv51 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv52 = nn.Conv2d(512, 512, 3, 1, 1)
        self.conv53 = nn.Conv2d(512, 512, 3, 1, 1)

    def forward(self, x):
        out = self.ReLU(self.conv11(x))
        out = self.ReLU(self.conv12(out))
        out = self.maxpool1(out)
        out = self.ReLU(self.conv21(out))
        out = self.ReLU(self.conv22(out))
        out = self.maxpool2(out)
        out = self.ReLU(self.conv31(out))
        out = self.ReLU(self.conv32(out))
        out = self.ReLU(self.conv33(out))
        out = self.maxpool3(out)
        out = self.ReLU(self.conv41(out))
        out = self.ReLU(self.conv42(out))
        out = self.ReLU(self.conv43(out))
        out = self.maxpool4(out)
        out = self.ReLU(self.conv51(out))
        out = self.ReLU(self.conv52(out))
        out = self.conv53(out)
        return out


# Assume input range is [0, 1]
class MINCFeatureExtractor(nn.Module):
    def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, \
            device=torch.device('cpu')):
        super(MINCFeatureExtractor, self).__init__()

        self.features = MINCNet()
        self.features.load_state_dict(
            torch.load('../experiments/pretrained_models/VGG16minc_53.pth'), strict=True)
        self.features.eval()
        # No need to BP to variable
        for k, v in self.features.named_parameters():
            v.requires_grad = False

    def forward(self, x):
        output = self.features(x)
        return output
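For orientation, a minimal smoke test for the RRDBNet generator above. The tensor sizes and import root are illustrative assumptions, not part of this commit:

# Hypothetical smoke test for RRDBNet (values illustrative, not from this commit):
import torch
import models.modules.architecture as arch   # assuming esrgan_plus/codes is the import root

net = arch.RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, upscale=4).eval()
lr = torch.randn(1, 3, 32, 32)               # a 32x32 low-resolution patch
with torch.no_grad():
    sr = net(lr)
print(sr.shape)                              # expected: torch.Size([1, 3, 128, 128]) for 4x upscaling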
esrgan_plus/codes/models/modules/block.py
ADDED
@@ -0,0 +1,322 @@
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

####################
# Basic blocks
####################


def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1):
    # helper selecting activation
    # neg_slope: for leakyrelu and init of prelu
    # n_prelu: for p_relu num_parameters
    act_type = act_type.lower()
    if act_type == 'relu':
        layer = nn.ReLU(inplace)
    elif act_type == 'leakyrelu':
        layer = nn.LeakyReLU(neg_slope, inplace)
    elif act_type == 'prelu':
        layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    else:
        raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
    return layer


def norm(norm_type, nc):
    # helper selecting normalization layer
    norm_type = norm_type.lower()
    if norm_type == 'batch':
        layer = nn.BatchNorm2d(nc, affine=True)
    elif norm_type == 'instance':
        layer = nn.InstanceNorm2d(nc, affine=False)
    else:
        raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
    return layer


def pad(pad_type, padding):
    # helper selecting padding layer
    # if padding is 'zero', do by conv layers
    pad_type = pad_type.lower()
    if padding == 0:
        return None
    if pad_type == 'reflect':
        layer = nn.ReflectionPad2d(padding)
    elif pad_type == 'replicate':
        layer = nn.ReplicationPad2d(padding)
    else:
        raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
    return layer


def get_valid_padding(kernel_size, dilation):
    kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
    padding = (kernel_size - 1) // 2
    return padding


class ConcatBlock(nn.Module):
    # Concat the output of a submodule to its input
    def __init__(self, submodule):
        super(ConcatBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        output = torch.cat((x, self.sub(x)), dim=1)
        return output

    def __repr__(self):
        tmpstr = 'Identity .. \n|'
        modstr = self.sub.__repr__().replace('\n', '\n|')
        tmpstr = tmpstr + modstr
        return tmpstr


class ShortcutBlock(nn.Module):
    # Elementwise sum the output of a submodule to its input
    def __init__(self, submodule):
        super(ShortcutBlock, self).__init__()
        self.sub = submodule

    def forward(self, x):
        output = x + self.sub(x)
        return output

    def __repr__(self):
        tmpstr = 'Identity + \n|'
        modstr = self.sub.__repr__().replace('\n', '\n|')
        tmpstr = tmpstr + modstr
        return tmpstr


def sequential(*args):
    # Flatten Sequential. It unwraps nn.Sequential.
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError('sequential does not support OrderedDict input.')
        return args[0]  # No sequential is needed.
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            for submodule in module.children():
                modules.append(submodule)
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)


class GaussianNoise(nn.Module):
    def __init__(self, sigma=0.1, is_relative_detach=False):
        super().__init__()
        self.sigma = sigma
        self.is_relative_detach = is_relative_detach
        self.noise = torch.tensor(0, dtype=torch.float)

    def forward(self, x):
        if self.training and self.sigma != 0:
            scale = self.sigma * x.detach() if self.is_relative_detach else self.sigma * x
            # move the noise template to the input's device before sampling
            sampled_noise = self.noise.to(x.device).repeat(*x.size()).normal_() * scale
            x = x + sampled_noise
        return x


def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True, \
        pad_type='zero', norm_type=None, act_type='relu', mode='CNA'):
    '''
    Conv layer with padding, normalization, activation
    mode: CNA --> Conv -> Norm -> Act
        NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
    '''
    assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
    padding = get_valid_padding(kernel_size, dilation)
    p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
    padding = padding if pad_type == 'zero' else 0

    c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, \
        dilation=dilation, bias=bias, groups=groups)
    a = act(act_type) if act_type else None
    if 'CNA' in mode:
        n = norm(norm_type, out_nc) if norm_type else None
        return sequential(p, c, n, a)
    elif mode == 'NAC':
        if norm_type is None and act_type is not None:
            a = act(act_type, inplace=False)
            # Important!
            # input----ReLU(inplace)----Conv--+----output
            #        |________________________|
            # inplace ReLU will modify the input, therefore wrong output
        n = norm(norm_type, in_nc) if norm_type else None
        return sequential(n, a, p, c)


def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


# https://github.com/github-pengge/PyTorch-progressive_growing_of_gans/blob/master/models/base_model.py
class minibatch_std_concat_layer(nn.Module):
    def __init__(self, averaging='all'):
        super(minibatch_std_concat_layer, self).__init__()
        self.averaging = averaging.lower()
        if 'group' in self.averaging:
            self.n = int(self.averaging[5:])
        else:
            assert self.averaging in ['all', 'flat', 'spatial', 'none', 'gpool'], \
                'Invalid averaging mode %s' % self.averaging
        self.adjusted_std = lambda x, **kwargs: torch.sqrt(
            torch.mean((x - torch.mean(x, **kwargs)) ** 2, **kwargs) + 1e-8)

    def forward(self, x):
        shape = list(x.size())
        target_shape = copy.deepcopy(shape)
        vals = self.adjusted_std(x, dim=0, keepdim=True)
        if self.averaging == 'all':
            target_shape[1] = 1
            vals = torch.mean(vals, dim=1, keepdim=True)
        elif self.averaging == 'spatial':
            if len(shape) == 4:
                vals = torch.mean(vals, dim=[2, 3], keepdim=True)
        elif self.averaging == 'none':
            target_shape = [target_shape[0]] + [s for s in target_shape[1:]]
        elif self.averaging == 'gpool':
            if len(shape) == 4:
                vals = torch.mean(x, dim=[0, 2, 3], keepdim=True)
        elif self.averaging == 'flat':
            target_shape[1] = 1
            vals = torch.FloatTensor([self.adjusted_std(x)])
        else:  # self.averaging == 'group'
            target_shape[1] = self.n
            vals = vals.view(self.n, shape[1] // self.n, shape[2], shape[3])
            vals = torch.mean(vals, dim=0, keepdim=True).view(1, self.n, 1, 1)
        vals = vals.expand(*target_shape)
        return torch.cat([x, vals], 1)


####################
# Useful blocks
####################


class ResNetBlock(nn.Module):
    '''
    ResNet Block, 3-3 style
    with extra residual scaling used in EDSR
    (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
    '''

    def __init__(self, in_nc, mid_nc, out_nc, kernel_size=3, stride=1, dilation=1, groups=1, \
            bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA', res_scale=1):
        super(ResNetBlock, self).__init__()
        conv0 = conv_block(in_nc, mid_nc, kernel_size, stride, dilation, groups, bias, pad_type, \
            norm_type, act_type, mode)
        if mode == 'CNA':
            act_type = None
        if mode == 'CNAC':  # Residual path: |-CNAC-|
            act_type = None
            norm_type = None
        conv1 = conv_block(mid_nc, out_nc, kernel_size, stride, dilation, groups, bias, pad_type, \
            norm_type, act_type, mode)
        # if in_nc != out_nc:
        #     self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \
        #         None, None)
        #     print('Need a projecter in ResNetBlock.')
        # else:
        #     self.project = lambda x:x
        self.res = sequential(conv0, conv1)
        self.res_scale = res_scale

    def forward(self, x):
        res = self.res(x).mul(self.res_scale)
        return x + res


class ResidualDenseBlock_5C(nn.Module):
    '''
    Residual Dense Block
    style: 5 convs
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    '''

    def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA', gaussian_noise=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.noise = GaussianNoise() if gaussian_noise else None
        self.conv1x1 = conv1x1(nc, gc)
        self.conv1 = conv_block(nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv2 = conv_block(nc+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv3 = conv_block(nc+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode)
        self.conv4 = conv_block(nc+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=act_type, mode=mode)
        if mode == 'CNA':
            last_act = None
        else:
            last_act = act_type
        self.conv5 = conv_block(nc+4*gc, nc, 3, stride, bias=bias, pad_type=pad_type, \
            norm_type=norm_type, act_type=last_act, mode=mode)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        x2 = x2 + self.conv1x1(x)  # additional 1x1 residual shortcut
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        x4 = x4 + x2  # additional internal residual connection
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return self.noise(x5.mul(0.2) + x)


class RRDB(nn.Module):
    '''
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    '''

    def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
            norm_type=None, act_type='leakyrelu', mode='CNA'):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
            norm_type, act_type, mode)
        self.RDB2 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
            norm_type, act_type, mode)
        self.RDB3 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
            norm_type, act_type, mode)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out.mul(0.2) + x


####################
# Upsampler
####################


def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \
        pad_type='zero', norm_type=None, act_type='relu'):
    '''
    Pixel shuffle layer
    (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
    Neural Network, CVPR17)
    '''
    conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias, \
        pad_type=pad_type, norm_type=None, act_type=None)
    pixel_shuffle = nn.PixelShuffle(upscale_factor)

    n = norm(norm_type, out_nc) if norm_type else None
    a = act(act_type) if act_type else None
    return sequential(conv, pixel_shuffle, n, a)


def upconv_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True, \
        pad_type='zero', norm_type=None, act_type='relu', mode='nearest'):
    # Up conv
    # described in https://distill.pub/2016/deconv-checkerboard/
    upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
    conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias, \
        pad_type=pad_type, norm_type=norm_type, act_type=act_type)
    return sequential(upsample, conv)
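For orientation, a minimal sketch of the block helpers above; all values are illustrative assumptions, and the import path assumes the package layout of the original ESRGAN code with esrgan_plus/codes as the import root:

# Hypothetical sketch of the block helpers (values illustrative, not from this commit):
import torch
import models.modules.block as B

layer = B.conv_block(64, 64, kernel_size=3, norm_type=None, act_type='leakyrelu', mode='CNA')
rrdb = B.RRDB(64, gc=32)                    # Residual-in-Residual Dense Block used by RRDBNet
feat = torch.randn(1, 64, 24, 24)
print(layer(feat).shape, rrdb(feat).shape)  # both preserve the 1x64x24x24 shape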
esrgan_plus/codes/models/modules/loss.py
ADDED
@@ -0,0 +1,60 @@
import torch
import torch.nn as nn


# Define GAN loss: [vanilla | lsgan | wgan-gp]
class GANLoss(nn.Module):
    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
        super(GANLoss, self).__init__()
        self.gan_type = gan_type.lower()
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan-gp':

            def wgan_loss(input, target):
                # target is boolean
                return -1 * input.mean() if target else input.mean()

            self.loss = wgan_loss
        else:
            raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))

    def get_target_label(self, input, target_is_real):
        if self.gan_type == 'wgan-gp':
            return target_is_real
        if target_is_real:
            return torch.empty_like(input).fill_(self.real_label_val)
        else:
            return torch.empty_like(input).fill_(self.fake_label_val)

    def forward(self, input, target_is_real):
        target_label = self.get_target_label(input, target_is_real)
        loss = self.loss(input, target_label)
        return loss


class GradientPenaltyLoss(nn.Module):
    def __init__(self, device=torch.device('cpu')):
        super(GradientPenaltyLoss, self).__init__()
        self.register_buffer('grad_outputs', torch.Tensor())
        self.grad_outputs = self.grad_outputs.to(device)

    def get_grad_outputs(self, input):
        if self.grad_outputs.size() != input.size():
            self.grad_outputs.resize_(input.size()).fill_(1.0)
        return self.grad_outputs

    def forward(self, interp, interp_crit):
        grad_outputs = self.get_grad_outputs(interp_crit)
        grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp, \
            grad_outputs=grad_outputs, create_graph=True, retain_graph=True, only_inputs=True)[0]
        grad_interp = grad_interp.view(grad_interp.size(0), -1)
        grad_interp_norm = grad_interp.norm(2, dim=1)

        loss = ((grad_interp_norm - 1)**2).mean()
        return loss
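For orientation, a minimal sketch of GANLoss from the two sides of adversarial training; the logits and batch size are illustrative assumptions, not part of this commit:

# Hypothetical sketch of GANLoss (values illustrative, not from this commit):
import torch
from models.modules.loss import GANLoss   # assuming esrgan_plus/codes as the import root

cri_gan = GANLoss('vanilla', real_label_val=1.0, fake_label_val=0.0)
d_out_fake = torch.randn(4, 1)                 # discriminator logits for generated images
g_loss = cri_gan(d_out_fake, True)             # generator wants fakes scored as real
d_loss = cri_gan(d_out_fake.detach(), False)   # discriminator wants them scored as fake
print(g_loss.item(), d_loss.item())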
esrgan_plus/codes/models/modules/seg_arch.py
ADDED
@@ -0,0 +1,70 @@
'''
architecture for segmentation
'''
import torch.nn as nn
from . import block as B


class Res131(nn.Module):
    def __init__(self, in_nc, mid_nc, out_nc, dilation=1, stride=1):
        super(Res131, self).__init__()
        conv0 = B.conv_block(in_nc, mid_nc, 1, 1, 1, 1, False, 'zero', 'batch')
        conv1 = B.conv_block(mid_nc, mid_nc, 3, stride, dilation, 1, False, 'zero', 'batch')
        conv2 = B.conv_block(mid_nc, out_nc, 1, 1, 1, 1, False, 'zero', 'batch', None)  # No ReLU
        self.res = B.sequential(conv0, conv1, conv2)
        if in_nc == out_nc:
            self.has_proj = False
        else:
            self.has_proj = True
            self.proj = B.conv_block(in_nc, out_nc, 1, stride, 1, 1, False, 'zero', 'batch', None)
            # No ReLU

    def forward(self, x):
        res = self.res(x)
        if self.has_proj:
            x = self.proj(x)
        return nn.functional.relu(x + res, inplace=True)


class OutdoorSceneSeg(nn.Module):
    def __init__(self):
        super(OutdoorSceneSeg, self).__init__()
        # conv1
        conv1_1 = B.conv_block(3, 64, 3, 2, 1, 1, False, 'zero', 'batch')  # /2
        conv1_2 = B.conv_block(64, 64, 3, 1, 1, 1, False, 'zero', 'batch')
        conv1_3 = B.conv_block(64, 128, 3, 1, 1, 1, False, 'zero', 'batch')
        max_pool = nn.MaxPool2d(3, stride=2, padding=0, ceil_mode=True)  # /2
        blocks = [conv1_1, conv1_2, conv1_3, max_pool]
        # conv2, 3 blocks
        blocks.append(Res131(128, 64, 256))
        for i in range(2):
            blocks.append(Res131(256, 64, 256))
        # conv3, 4 blocks
        blocks.append(Res131(256, 128, 512, 1, 2))  # /2
        for i in range(3):
            blocks.append(Res131(512, 128, 512))
        # conv4, 23 blocks
        blocks.append(Res131(512, 256, 1024, 2))
        for i in range(22):
            blocks.append(Res131(1024, 256, 1024, 2))
        # conv5
        blocks.append(Res131(1024, 512, 2048, 4))
        blocks.append(Res131(2048, 512, 2048, 4))
        blocks.append(Res131(2048, 512, 2048, 4))
        blocks.append(B.conv_block(2048, 512, 3, 1, 1, 1, False, 'zero', 'batch'))
        blocks.append(nn.Dropout(0.1))
        # conv6
        blocks.append(nn.Conv2d(512, 8, 1, 1))

        self.feature = B.sequential(*blocks)
        # deconv
        self.deconv = nn.ConvTranspose2d(8, 8, 16, 8, 4, 0, 8, False, 1)
        # softmax
        self.softmax = nn.Softmax(1)

    def forward(self, x):
        x = self.feature(x)
        x = self.deconv(x)
        x = self.softmax(x)
        return x
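For orientation, a minimal shape sketch for the segmentation network above; the input size is an illustrative assumption, not part of this commit:

# Hypothetical sketch of OutdoorSceneSeg (input size illustrative, not from this commit):
import torch
from models.modules.seg_arch import OutdoorSceneSeg   # assuming esrgan_plus/codes as root

seg = OutdoorSceneSeg().eval()
img = torch.randn(1, 3, 96, 96)
with torch.no_grad():
    prob = seg(img)          # per-pixel probabilities over 8 outdoor-scene categories
print(prob.shape)            # expected: torch.Size([1, 8, 96, 96]); the deconv restores full resolution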
esrgan_plus/codes/models/modules/sft_arch.py
ADDED
@@ -0,0 +1,226 @@
'''
architecture for sft
'''
import torch.nn as nn
import torch.nn.functional as F


class SFTLayer(nn.Module):
    def __init__(self):
        super(SFTLayer, self).__init__()
        self.SFT_scale_conv0 = nn.Conv2d(32, 32, 1)
        self.SFT_scale_conv1 = nn.Conv2d(32, 64, 1)
        self.SFT_shift_conv0 = nn.Conv2d(32, 32, 1)
        self.SFT_shift_conv1 = nn.Conv2d(32, 64, 1)

    def forward(self, x):
        # x[0]: fea; x[1]: cond
        scale = self.SFT_scale_conv1(F.leaky_relu(self.SFT_scale_conv0(x[1]), 0.1, inplace=True))
        shift = self.SFT_shift_conv1(F.leaky_relu(self.SFT_shift_conv0(x[1]), 0.1, inplace=True))
        return x[0] * (scale + 1) + shift


class ResBlock_SFT(nn.Module):
    def __init__(self):
        super(ResBlock_SFT, self).__init__()
        self.sft0 = SFTLayer()
        self.conv0 = nn.Conv2d(64, 64, 3, 1, 1)
        self.sft1 = SFTLayer()
        self.conv1 = nn.Conv2d(64, 64, 3, 1, 1)

    def forward(self, x):
        # x[0]: fea; x[1]: cond
        fea = self.sft0(x)
        fea = F.relu(self.conv0(fea), inplace=True)
        fea = self.sft1((fea, x[1]))
        fea = self.conv1(fea)
        return (x[0] + fea, x[1])  # return a tuple containing features and conditions


class SFT_Net(nn.Module):
    def __init__(self):
        super(SFT_Net, self).__init__()
        self.conv0 = nn.Conv2d(3, 64, 3, 1, 1)

        sft_branch = []
        for i in range(16):
            sft_branch.append(ResBlock_SFT())
        sft_branch.append(SFTLayer())
        sft_branch.append(nn.Conv2d(64, 64, 3, 1, 1))
        self.sft_branch = nn.Sequential(*sft_branch)

        self.HR_branch = nn.Sequential(
            nn.Conv2d(64, 256, 3, 1, 1),
            nn.PixelShuffle(2),
            nn.ReLU(True),
            nn.Conv2d(64, 256, 3, 1, 1),
            nn.PixelShuffle(2),
            nn.ReLU(True),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(64, 3, 3, 1, 1)
        )

        self.CondNet = nn.Sequential(
            nn.Conv2d(8, 128, 4, 4),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(128, 128, 1),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(128, 128, 1),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(128, 128, 1),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(128, 32, 1)
        )

    def forward(self, x):
        # x[0]: img; x[1]: seg
        cond = self.CondNet(x[1])
        fea = self.conv0(x[0])
        res = self.sft_branch((fea, cond))
        fea = fea + res
        out = self.HR_branch(fea)
        return out


# Auxiliary Classifier Discriminator
class ACD_VGG_BN_96(nn.Module):
    def __init__(self):
        super(ACD_VGG_BN_96, self).__init__()

        self.feature = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.LeakyReLU(0.1, True),

            nn.Conv2d(64, 64, 4, 2, 1),
            nn.BatchNorm2d(64, affine=True),
            nn.LeakyReLU(0.1, True),

            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128, affine=True),
            nn.LeakyReLU(0.1, True),

            nn.Conv2d(128, 128, 4, 2, 1),
            nn.BatchNorm2d(128, affine=True),
            nn.LeakyReLU(0.1, True),

            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256, affine=True),
            nn.LeakyReLU(0.1, True),

            nn.Conv2d(256, 256, 4, 2, 1),
            nn.BatchNorm2d(256, affine=True),
            nn.LeakyReLU(0.1, True),

            nn.Conv2d(256, 512, 3, 1, 1),
            nn.BatchNorm2d(512, affine=True),
            nn.LeakyReLU(0.1, True),

            nn.Conv2d(512, 512, 4, 2, 1),
            nn.BatchNorm2d(512, affine=True),
            nn.LeakyReLU(0.1, True),
        )

        # gan
        self.gan = nn.Sequential(
            nn.Linear(512*6*6, 100),
            nn.LeakyReLU(0.1, True),
            nn.Linear(100, 1)
        )

        self.cls = nn.Sequential(
            nn.Linear(512*6*6, 100),
            nn.LeakyReLU(0.1, True),
            nn.Linear(100, 8)
        )

    def forward(self, x):
        fea = self.feature(x)
        fea = fea.view(fea.size(0), -1)
        gan = self.gan(fea)
        cls = self.cls(fea)
        return [gan, cls]


#############################################
# below is the sft arch for the torch version
#############################################


class SFTLayer_torch(nn.Module):
    def __init__(self):
        super(SFTLayer_torch, self).__init__()
        self.SFT_scale_conv0 = nn.Conv2d(32, 32, 1)
        self.SFT_scale_conv1 = nn.Conv2d(32, 64, 1)
        self.SFT_shift_conv0 = nn.Conv2d(32, 32, 1)
        self.SFT_shift_conv1 = nn.Conv2d(32, 64, 1)

    def forward(self, x):
        # x[0]: fea; x[1]: cond
        scale = self.SFT_scale_conv1(F.leaky_relu(self.SFT_scale_conv0(x[1]), 0.01, inplace=True))
        shift = self.SFT_shift_conv1(F.leaky_relu(self.SFT_shift_conv0(x[1]), 0.01, inplace=True))
        return x[0] * scale + shift


class ResBlock_SFT_torch(nn.Module):
    def __init__(self):
        super(ResBlock_SFT_torch, self).__init__()
        self.sft0 = SFTLayer_torch()
        self.conv0 = nn.Conv2d(64, 64, 3, 1, 1)
        self.sft1 = SFTLayer_torch()
        self.conv1 = nn.Conv2d(64, 64, 3, 1, 1)

    def forward(self, x):
        # x[0]: fea; x[1]: cond
        fea = F.relu(self.sft0(x), inplace=True)
        fea = self.conv0(fea)
        fea = F.relu(self.sft1((fea, x[1])), inplace=True)
        fea = self.conv1(fea)
        return (x[0] + fea, x[1])  # return a tuple containing features and conditions


class SFT_Net_torch(nn.Module):
    def __init__(self):
        super(SFT_Net_torch, self).__init__()
        self.conv0 = nn.Conv2d(3, 64, 3, 1, 1)

        sft_branch = []
        for i in range(16):
            sft_branch.append(ResBlock_SFT_torch())
        sft_branch.append(SFTLayer_torch())
        sft_branch.append(nn.Conv2d(64, 64, 3, 1, 1))
        self.sft_branch = nn.Sequential(*sft_branch)

        self.HR_branch = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.ReLU(True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.ReLU(True),
            nn.Conv2d(64, 3, 3, 1, 1)
        )

        # Condition network
        self.CondNet = nn.Sequential(
            nn.Conv2d(8, 128, 4, 4),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(128, 128, 1),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(128, 128, 1),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(128, 128, 1),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(128, 32, 1)
        )

    def forward(self, x):
        # x[0]: img; x[1]: seg
        cond = self.CondNet(x[1])
        fea = self.conv0(x[0])
        res = self.sft_branch((fea, cond))
        fea = fea + res
        out = self.HR_branch(fea)
        return out
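For orientation, a minimal sketch of driving SFT_Net with an image and its segmentation condition; the tensor sizes are illustrative assumptions, not part of this commit:

# Hypothetical sketch of SFT_Net (sizes illustrative, not from this commit):
import torch
from models.modules.sft_arch import SFT_Net   # assuming esrgan_plus/codes as root

net = SFT_Net().eval()
img = torch.randn(1, 3, 24, 24)     # low-resolution image
seg = torch.randn(1, 8, 96, 96)     # 8-channel segmentation probability map at HR size
with torch.no_grad():
    out = net((img, seg))           # CondNet downsamples seg 4x to match the LR features
print(out.shape)                    # expected: torch.Size([1, 3, 96, 96]) (4x upscaling)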
esrgan_plus/codes/models/modules/spectral_norm.py
ADDED
@@ -0,0 +1,149 @@
+'''
+Copy from pytorch github repo
+Spectral Normalization from https://arxiv.org/abs/1802.05957
+'''
+import torch
+from torch.nn.functional import normalize
+from torch.nn.parameter import Parameter
+
+
+class SpectralNorm(object):
+    def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
+        self.name = name
+        self.dim = dim
+        if n_power_iterations <= 0:
+            raise ValueError('Expected n_power_iterations to be positive, but '
+                             'got n_power_iterations={}'.format(n_power_iterations))
+        self.n_power_iterations = n_power_iterations
+        self.eps = eps
+
+    def compute_weight(self, module):
+        weight = getattr(module, self.name + '_orig')
+        u = getattr(module, self.name + '_u')
+        weight_mat = weight
+        if self.dim != 0:
+            # permute dim to front
+            weight_mat = weight_mat.permute(self.dim,
+                                            *[d for d in range(weight_mat.dim()) if d != self.dim])
+        height = weight_mat.size(0)
+        weight_mat = weight_mat.reshape(height, -1)
+        with torch.no_grad():
+            for _ in range(self.n_power_iterations):
+                # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
+                # are the first left and right singular vectors.
+                # This power iteration produces approximations of `u` and `v`.
+                v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps)
+                u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps)
+
+        sigma = torch.dot(u, torch.matmul(weight_mat, v))
+        weight = weight / sigma
+        return weight, u
+
+    def remove(self, module):
+        weight = getattr(module, self.name)
+        delattr(module, self.name)
+        delattr(module, self.name + '_u')
+        delattr(module, self.name + '_orig')
+        module.register_parameter(self.name, torch.nn.Parameter(weight))
+
+    def __call__(self, module, inputs):
+        if module.training:
+            weight, u = self.compute_weight(module)
+            setattr(module, self.name, weight)
+            setattr(module, self.name + '_u', u)
+        else:
+            r_g = getattr(module, self.name + '_orig').requires_grad
+            getattr(module, self.name).detach_().requires_grad_(r_g)
+
+    @staticmethod
+    def apply(module, name, n_power_iterations, dim, eps):
+        fn = SpectralNorm(name, n_power_iterations, dim, eps)
+        weight = module._parameters[name]
+        height = weight.size(dim)
+
+        u = normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps)
+        delattr(module, fn.name)
+        module.register_parameter(fn.name + "_orig", weight)
+        # We still need to assign weight back as fn.name because all sorts of
+        # things may assume that it exists, e.g., when initializing weights.
+        # However, we can't directly assign as it could be an nn.Parameter and
+        # gets added as a parameter. Instead, we register weight.data as a
+        # buffer, which will cause weight to be included in the state dict
+        # and also supports nn.init due to shared storage.
+        module.register_buffer(fn.name, weight.data)
+        module.register_buffer(fn.name + "_u", u)
+
+        module.register_forward_pre_hook(fn)
+        return fn
+
+
+def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None):
+    r"""Applies spectral normalization to a parameter in the given module.
+
+    .. math::
+        \mathbf{W} &= \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\
+        \sigma(\mathbf{W}) &= \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
+
+    Spectral normalization stabilizes the training of discriminators (critics)
+    in Generative Adversarial Networks (GANs) by rescaling the weight tensor
+    with spectral norm :math:`\sigma` of the weight matrix calculated using
+    power iteration method. If the dimension of the weight tensor is greater
+    than 2, it is reshaped to 2D in power iteration method to get spectral
+    norm. This is implemented via a hook that calculates spectral norm and
+    rescales weight before every :meth:`~Module.forward` call.
+
+    See `Spectral Normalization for Generative Adversarial Networks`_ .
+
+    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
+
+    Args:
+        module (nn.Module): containing module
+        name (str, optional): name of weight parameter
+        n_power_iterations (int, optional): number of power iterations to
+            calculate spectral norm
+        eps (float, optional): epsilon for numerical stability in
+            calculating norms
+        dim (int, optional): dimension corresponding to number of outputs,
+            the default is 0, except for modules that are instances of
+            ConvTranspose1/2/3d, when it is 1
+
+    Returns:
+        The original module with the spectral norm hook
+
+    Example::
+
+        >>> m = spectral_norm(nn.Linear(20, 40))
+        Linear (20 -> 40)
+        >>> m.weight_u.size()
+        torch.Size([20])
+
+    """
+    if dim is None:
+        if isinstance(
+                module,
+                (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d)):
+            dim = 1
+        else:
+            dim = 0
+    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
+    return module
+
+
+def remove_spectral_norm(module, name='weight'):
+    r"""Removes the spectral normalization reparameterization from a module.
+
+    Args:
+        module (nn.Module): containing module
+        name (str, optional): name of weight parameter
+
+    Example:
+        >>> m = spectral_norm(nn.Linear(40, 10))
+        >>> remove_spectral_norm(m)
+    """
+    for k, hook in module._forward_pre_hooks.items():
+        if isinstance(hook, SpectralNorm) and hook.name == name:
+            hook.remove(module)
+            del module._forward_pre_hooks[k]
+            return module
+
+    raise ValueError("spectral_norm of '{}' not found in {}".format(name, module))
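For orientation, a minimal sketch of how this hook is typically attached to a layer; the layer shape and the import path are illustrative assumptions, not taken from this commit:

import torch
import torch.nn as nn
from models.modules.spectral_norm import spectral_norm, remove_spectral_norm  # assumed import path

conv = spectral_norm(nn.Conv2d(64, 64, 3, padding=1))  # registers the forward pre-hook
y = conv(torch.randn(1, 64, 32, 32))  # weight is divided by its estimated sigma before forward
print(conv.weight_u.size())           # torch.Size([64]): persistent left singular vector
remove_spectral_norm(conv)            # folds the normalized weight back into a plain Parameter

In training mode each forward refines u and v by one more power iteration, so the sigma estimate tracks the weight as the optimizer updates it.
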
esrgan_plus/codes/models/networks.py
ADDED
@@ -0,0 +1,155 @@
+import functools
+import logging
+import torch
+import torch.nn as nn
+from torch.nn import init
+
+import models.modules.architecture as arch
+import models.modules.sft_arch as sft_arch
+logger = logging.getLogger('base')
+####################
+# initialize
+####################
+
+
+def weights_init_normal(m, std=0.02):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        init.normal_(m.weight.data, 0.0, std)
+        if m.bias is not None:
+            m.bias.data.zero_()
+    elif classname.find('Linear') != -1:
+        init.normal_(m.weight.data, 0.0, std)
+        if m.bias is not None:
+            m.bias.data.zero_()
+    elif classname.find('BatchNorm2d') != -1:
+        init.normal_(m.weight.data, 1.0, std)  # BN also uses norm
+        init.constant_(m.bias.data, 0.0)
+
+
+def weights_init_kaiming(m, scale=1):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+        m.weight.data *= scale
+        if m.bias is not None:
+            m.bias.data.zero_()
+    elif classname.find('Linear') != -1:
+        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+        m.weight.data *= scale
+        if m.bias is not None:
+            m.bias.data.zero_()
+    elif classname.find('BatchNorm2d') != -1:
+        init.constant_(m.weight.data, 1.0)
+        init.constant_(m.bias.data, 0.0)
+
+
+def weights_init_orthogonal(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        init.orthogonal_(m.weight.data, gain=1)
+        if m.bias is not None:
+            m.bias.data.zero_()
+    elif classname.find('Linear') != -1:
+        init.orthogonal_(m.weight.data, gain=1)
+        if m.bias is not None:
+            m.bias.data.zero_()
+    elif classname.find('BatchNorm2d') != -1:
+        init.constant_(m.weight.data, 1.0)
+        init.constant_(m.bias.data, 0.0)
+
+
+def init_weights(net, init_type='kaiming', scale=1, std=0.02):
+    # scale for 'kaiming', std for 'normal'.
+    logger.info('Initialization method [{:s}]'.format(init_type))
+    if init_type == 'normal':
+        weights_init_normal_ = functools.partial(weights_init_normal, std=std)
+        net.apply(weights_init_normal_)
+    elif init_type == 'kaiming':
+        weights_init_kaiming_ = functools.partial(weights_init_kaiming, scale=scale)
+        net.apply(weights_init_kaiming_)
+    elif init_type == 'orthogonal':
+        net.apply(weights_init_orthogonal)
+    else:
+        raise NotImplementedError('initialization method [{:s}] not implemented'.format(init_type))
+
+
+####################
+# define network
+####################
+
+
+# Generator
+def define_G(opt):
+    gpu_ids = opt['gpu_ids']
+    opt_net = opt['network_G']
+    which_model = opt_net['which_model_G']
+
+    if which_model == 'sr_resnet':  # SRResNet
+        netG = arch.SRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'], \
+            nb=opt_net['nb'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'], \
+            act_type='relu', mode=opt_net['mode'], upsample_mode='pixelshuffle')
+
+    elif which_model == 'sft_arch':  # SFT-GAN
+        netG = sft_arch.SFT_Net()
+
+    elif which_model == 'RRDB_net':  # RRDB
+        netG = arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], nf=opt_net['nf'],
+            nb=opt_net['nb'], gc=opt_net['gc'], upscale=opt_net['scale'], norm_type=opt_net['norm_type'],
+            act_type='leakyrelu', mode=opt_net['mode'], upsample_mode='upconv')
+    else:
+        raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
+
+    if opt['is_train']:
+        init_weights(netG, init_type='kaiming', scale=0.1)
+    if gpu_ids:
+        assert torch.cuda.is_available()
+        netG = nn.DataParallel(netG)
+    return netG
+
+
+# Discriminator
+def define_D(opt):
+    gpu_ids = opt['gpu_ids']
+    opt_net = opt['network_D']
+    which_model = opt_net['which_model_D']
+
+    if which_model == 'discriminator_vgg_128':
+        netD = arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \
+            norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
+
+    elif which_model == 'dis_acd':  # sft-gan, Auxiliary Classifier Discriminator
+        netD = sft_arch.ACD_VGG_BN_96()
+
+    elif which_model == 'discriminator_vgg_96':
+        netD = arch.Discriminator_VGG_96(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \
+            norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
+    elif which_model == 'discriminator_vgg_192':
+        netD = arch.Discriminator_VGG_192(in_nc=opt_net['in_nc'], base_nf=opt_net['nf'], \
+            norm_type=opt_net['norm_type'], mode=opt_net['mode'], act_type=opt_net['act_type'])
+    elif which_model == 'discriminator_vgg_128_SN':
+        netD = arch.Discriminator_VGG_128_SN()
+    else:
+        raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
+
+    init_weights(netD, init_type='kaiming', scale=1)
+    if gpu_ids:
+        netD = nn.DataParallel(netD)
+    return netD
+
+
+def define_F(opt, use_bn=False):
+    gpu_ids = opt['gpu_ids']
+    device = torch.device('cuda' if gpu_ids else 'cpu')
+    # pytorch pretrained VGG19-54, before ReLU.
+    if use_bn:
+        feature_layer = 49
+    else:
+        feature_layer = 34
+    netF = arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn, \
+        use_input_norm=True, device=device)
+    # netF = arch.ResNet101FeatureExtractor(use_input_norm=True, device=device)
+    if gpu_ids:
+        netF = nn.DataParallel(netF)
+    netF.eval()  # No need to train
+    return netF
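A quick sketch of calling define_G with a hand-built options dict; the network_G values are the ones the RRDB configs later in this diff use, and the import assumes the working directory is esrgan_plus/codes:

import models.networks as networks  # run from esrgan_plus/codes

opt = {
    'is_train': False,  # skip the Kaiming re-initialization branch
    'gpu_ids': [],      # empty list: no DataParallel wrapping, CPU tensors
    'network_G': {
        'which_model_G': 'RRDB_net', 'norm_type': None, 'mode': 'CNA',
        'nf': 64, 'nb': 23, 'in_nc': 3, 'out_nc': 3, 'gc': 32, 'scale': 4,
    },
}
netG = networks.define_G(opt)  # -> arch.RRDBNet configured for 4x upscaling

In normal use options.parse() (next file) builds this dict from a JSON file and injects the 'scale' key into network_G itself.
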
esrgan_plus/codes/options/options.py
ADDED
@@ -0,0 +1,120 @@
+import os
+import os.path as osp
+import logging
+from collections import OrderedDict
+import json
+
+
+def parse(opt_path, is_train=True):
+    # remove comments starting with '//'
+    json_str = ''
+    with open(opt_path, 'r') as f:
+        for line in f:
+            line = line.split('//')[0] + '\n'
+            json_str += line
+    opt = json.loads(json_str, object_pairs_hook=OrderedDict)
+
+    opt['is_train'] = is_train
+    scale = opt['scale']
+
+    # datasets
+    for phase, dataset in opt['datasets'].items():
+        phase = phase.split('_')[0]
+        dataset['phase'] = phase
+        dataset['scale'] = scale
+        is_lmdb = False
+        if 'dataroot_HR' in dataset and dataset['dataroot_HR'] is not None:
+            dataset['dataroot_HR'] = os.path.expanduser(dataset['dataroot_HR'])
+            if dataset['dataroot_HR'].endswith('lmdb'):
+                is_lmdb = True
+        if 'dataroot_HR_bg' in dataset and dataset['dataroot_HR_bg'] is not None:
+            dataset['dataroot_HR_bg'] = os.path.expanduser(dataset['dataroot_HR_bg'])
+        if 'dataroot_LR' in dataset and dataset['dataroot_LR'] is not None:
+            dataset['dataroot_LR'] = os.path.expanduser(dataset['dataroot_LR'])
+            if dataset['dataroot_LR'].endswith('lmdb'):
+                is_lmdb = True
+        dataset['data_type'] = 'lmdb' if is_lmdb else 'img'
+
+        if phase == 'train' and 'subset_file' in dataset and dataset['subset_file'] is not None:
+            dataset['subset_file'] = os.path.expanduser(dataset['subset_file'])
+
+    # path
+    for key, path in opt['path'].items():
+        if path and key in opt['path']:
+            opt['path'][key] = os.path.expanduser(path)
+    if is_train:
+        experiments_root = os.path.join(opt['path']['root'], 'experiments', opt['name'])
+        opt['path']['experiments_root'] = experiments_root
+        opt['path']['models'] = os.path.join(experiments_root, 'models')
+        opt['path']['training_state'] = os.path.join(experiments_root, 'training_state')
+        opt['path']['log'] = experiments_root
+        opt['path']['val_images'] = os.path.join(experiments_root, 'val_images')
+
+        # change some options for debug mode
+        if 'debug' in opt['name']:
+            opt['train']['val_freq'] = 8
+            opt['logger']['print_freq'] = 2
+            opt['logger']['save_checkpoint_freq'] = 8
+            opt['train']['lr_decay_iter'] = 10
+    else:  # test
+        results_root = os.path.join(opt['path']['root'], 'results', opt['name'])
+        opt['path']['results_root'] = results_root
+        opt['path']['log'] = results_root
+
+    # network
+    opt['network_G']['scale'] = scale
+
+    # export CUDA_VISIBLE_DEVICES
+    gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
+    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
+    print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
+
+    return opt
+
+
+class NoneDict(dict):
+    def __missing__(self, key):
+        return None
+
+
+# convert to NoneDict, which returns None for missing keys.
+def dict_to_nonedict(opt):
+    if isinstance(opt, dict):
+        new_opt = dict()
+        for key, sub_opt in opt.items():
+            new_opt[key] = dict_to_nonedict(sub_opt)
+        return NoneDict(**new_opt)
+    elif isinstance(opt, list):
+        return [dict_to_nonedict(sub_opt) for sub_opt in opt]
+    else:
+        return opt
+
+
+def dict2str(opt, indent_l=1):
+    '''dict to string for logger'''
+    msg = ''
+    for k, v in opt.items():
+        if isinstance(v, dict):
+            msg += ' ' * (indent_l * 2) + k + ':[\n'
+            msg += dict2str(v, indent_l + 1)
+            msg += ' ' * (indent_l * 2) + ']\n'
+        else:
+            msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
+    return msg
+
+
+def check_resume(opt):
+    '''Check resume states and pretrain_model paths'''
+    logger = logging.getLogger('base')
+    if opt['path']['resume_state']:
+        if opt['path']['pretrain_model_G'] or opt['path']['pretrain_model_D']:
+            logger.warning('pretrain_model path will be ignored when resuming training.')
+
+        state_idx = osp.basename(opt['path']['resume_state']).split('.')[0]
+        opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
+                                                   '{}_G.pth'.format(state_idx))
+        logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
+        if 'gan' in opt['model']:
+            opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
+                                                       '{}_D.pth'.format(state_idx))
+            logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
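The two behaviors callers lean on here are the '//' comment stripping and the NoneDict fallback; a self-contained sketch of the stripping rule:

import json
from collections import OrderedDict

raw = '''{
    "scale": 4 // comments like this are legal in these option files
    , "gpu_ids": [0]
}'''
json_str = ''
for line in raw.splitlines():
    json_str += line.split('//')[0] + '\n'  # same rule as parse() above
opt = json.loads(json_str, object_pairs_hook=OrderedDict)
print(opt['scale'])  # 4

Note the rule is purely textual: a string value containing '//' (a URL, say) would be truncated too, which is why the option files below keep '//' out of string values.
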
esrgan_plus/codes/options/test/test_ESRGANplus.json
ADDED
@@ -0,0 +1,40 @@
+{
+    "name": "nESRGAN+_x4"
+    , "suffix": "_ESRGAN"
+    , "model": "srragan"
+    , "scale": 4
+    , "gpu_ids": [0]
+
+    , "datasets": {
+        "test_1": { // the 1st test dataset
+            "name": "set5"
+            , "mode": "LRHR"
+            , "dataroot_HR": "/home/carraz/datasets/val_set5/Set5"
+            , "dataroot_LR": "/home/carraz/datasets/val_set5/Set5_bicLRx4"
+        }
+        , "test_2": { // the 2nd test dataset
+            "name": "set14"
+            , "mode": "LRHR"
+            , "dataroot_HR": "/home/carraz/datasets/val_set14/Set14"
+            , "dataroot_LR": "/home/carraz/datasets/val_set14/Set14_bicLRx4"
+        }
+    }
+
+    , "path": {
+        "root": "/home/carraz/nESRGANplus"
+        , "pretrain_model_G": "../experiments/pretrained_models/RRDB_ESRGAN_x4.pth"
+    }
+
+    , "network_G": {
+        "which_model_G": "RRDB_net" // RRDB_net | sr_resnet
+        , "norm_type": null
+        , "mode": "CNA"
+        , "nf": 64
+        , "nb": 23
+        , "in_nc": 3
+        , "out_nc": 3
+
+        , "gc": 32
+        , "group": 1
+    }
+}
esrgan_plus/codes/options/test/test_SRGAN.json
ADDED
@@ -0,0 +1,37 @@
+{
+    "name": "SRGAN"
+    , "suffix": "_SRGAN"
+    , "model": "srgan"
+    , "scale": 4
+    , "gpu_ids": [0]
+
+    , "datasets": {
+        "test_1": { // the 1st test dataset
+            "name": "set5"
+            , "mode": "LRHR"
+            , "dataroot_HR": "/home/carraz/datasets/val_set5/Set5"
+            , "dataroot_LR": "/home/carraz/datasets/val_set5/Set5_bicLRx4"
+        }
+        , "test_2": { // the 2nd test dataset
+            "name": "set14"
+            , "mode": "LRHR"
+            , "dataroot_HR": "/home/carraz/datasets/val_set14/Set14"
+            , "dataroot_LR": "/home/carraz/datasets/val_set14/Set14_bicLRx4"
+        }
+    }
+
+    , "path": {
+        "root": "/home/carraz/nESRGANplus"
+        , "pretrain_model_G": "../experiments/pretrained_models/SRGAN_bicx4_303_505.pth"
+    }
+
+    , "network_G": {
+        "which_model_G": "sr_resnet"
+        , "norm_type": null
+        , "mode": "CNA"
+        , "nf": 64
+        , "nb": 16
+        , "in_nc": 3
+        , "out_nc": 3
+    }
+}
esrgan_plus/codes/options/test/test_SRResNet.json
ADDED
@@ -0,0 +1,40 @@
+{
+    "name": "SRResNet_bicx4_in3nf64nb16"
+    , "suffix": null
+    , "model": "sr"
+    , "scale": 4
+    , "gpu_ids": [0]
+
+    , "datasets": {
+        "test_1": { // the 1st test dataset
+            "name": "set5"
+            , "mode": "LRHR"
+            , "dataroot_HR": "/home/carraz/datasets/val_set5/Set5"
+            , "dataroot_LR": "/home/carraz/datasets/val_set5/Set5_bicLRx4"
+        }
+        , "test_2": { // the 2nd test dataset
+            "name": "set14"
+            , "mode": "LRHR"
+            , "dataroot_HR": "/home/carraz/datasets/val_set14/Set14"
+            , "dataroot_LR": "/home/carraz/datasets/val_set14/Set14_bicLRx4"
+        }
+    }
+
+    , "path": {
+        "root": "/home/carraz/nESRGANplus"
+        , "pretrain_model_G": "../experiments/pretrained_models/SRResNet_bicx4_in3nf64nb16.pth"
+    }
+
+    , "network_G": {
+        "which_model_G": "sr_resnet" // RRDB_net | sr_resnet
+        , "norm_type": null
+        , "mode": "CNA"
+        , "nf": 64
+        , "nb": 16
+        , "in_nc": 3
+        , "out_nc": 3
+
+        , "gc": 32
+        , "group": 1
+    }
+}
esrgan_plus/codes/options/test/test_sr.json
ADDED
@@ -0,0 +1,40 @@
+{
+    "name": "RRDB_PSNR_x4"
+    , "suffix": null
+    , "model": "sr"
+    , "scale": 4
+    , "gpu_ids": [0]
+
+    , "datasets": {
+        "test_1": { // the 1st test dataset
+            "name": "set5"
+            , "mode": "LRHR"
+            , "dataroot_HR": "/home/carraz/datasets/val_set5/Set5"
+            , "dataroot_LR": "/home/carraz/datasets/val_set5/Set5_bicLRx4"
+        }
+        , "test_2": { // the 2nd test dataset
+            "name": "set14"
+            , "mode": "LRHR"
+            , "dataroot_HR": "/home/carraz/datasets/val_set14/Set14"
+            , "dataroot_LR": "/home/carraz/datasets/val_set14/Set14_bicLRx4"
+        }
+    }
+
+    , "path": {
+        "root": "/home/carraz/nESRGANplus"
+        , "pretrain_model_G": "../experiments/pretrained_models/RRDB_PSNR_x4.pth"
+    }
+
+    , "network_G": {
+        "which_model_G": "RRDB_net" // RRDB_net | sr_resnet
+        , "norm_type": null
+        , "mode": "CNA"
+        , "nf": 64
+        , "nb": 23
+        , "in_nc": 3
+        , "out_nc": 3
+
+        , "gc": 32
+        , "group": 1
+    }
+}
esrgan_plus/codes/options/train/train_ESRGANplus.json
ADDED
@@ -0,0 +1,83 @@
+{
+    "name": "nESRGANplus_x4_DIV2K",
+    "use_tb_logger": true,
+    "model": "srragan",
+    "scale": 4,
+    "gpu_ids": [
+        0
+    ],
+    "datasets": {
+        "train": {
+            "name": "DIV2K",
+            "mode": "LRHR",
+            "dataroot_HR": "/content/gdrive/My Drive/DIV2K_train_HR_sub",
+            "dataroot_LR": "/content/gdrive/My Drive/DIV2K_train_LR_sub_bicLRx4",
+            "subset_file": null,
+            "use_shuffle": true,
+            "n_workers": 8,
+            "batch_size": 16,
+            "HR_size": 128,
+            "use_flip": true,
+            "use_rot": true
+        },
+        "val": {
+            "name": "val_set14_part",
+            "mode": "LRHR",
+            "dataroot_HR": "/content/gdrive/My Drive/ESRGAN/Set14",
+            "dataroot_LR": "/content/gdrive/My Drive/ESRGAN/Set14_LR_sub_bicLRx4"
+        }
+    },
+    "path": {
+        "root": "/content/gdrive/My Drive/ESRGAN/BasicSR",
+        "resume_state": "/content/gdrive/My Drive/ESRGAN/BasicSR/experiments/002_RRDB_ESRGAN_x4_DIV2K/training_state/495000.state",
+        "pretrain_model_G": "/content/gdrive/My Drive/ESRGAN/RRDB_PSNR_x4.pth"
+    },
+    "network_G": {
+        "which_model_G": "RRDB_net",
+        "norm_type": null,
+        "mode": "CNA",
+        "nf": 64,
+        "nb": 23,
+        "in_nc": 3,
+        "out_nc": 3,
+        "gc": 32,
+        "group": 1
+    },
+    "network_D": {
+        "which_model_D": "discriminator_vgg_128",
+        "norm_type": "batch",
+        "act_type": "leakyrelu",
+        "mode": "CNA",
+        "nf": 64,
+        "in_nc": 3
+    },
+    "train": {
+        "lr_G": 0.0001,
+        "weight_decay_G": 0,
+        "beta1_G": 0.9,
+        "lr_D": 0.0001,
+        "weight_decay_D": 0,
+        "beta1_D": 0.9,
+        "lr_scheme": "MultiStepLR",
+        "lr_steps": [
+            50000,
+            100000,
+            200000,
+            300000
+        ],
+        "lr_gamma": 0.5,
+        "pixel_criterion": "l1",
+        "pixel_weight": 0.01,
+        "feature_criterion": "l1",
+        "feature_weight": 1,
+        "gan_type": "vanilla",
+        "gan_weight": 0.005,
+        "manual_seed": 0,
+        "niter": 500000.0,
+        "val_freq": 500.0
+    },
+    "logger": {
+        "print_freq": 50,
+        "save_checkpoint_freq": 500.0
+    }
+}
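With lr_steps [50000, 100000, 200000, 300000] and lr_gamma 0.5, the MultiStepLR scheme named above halves the learning rate at each milestone; a small sketch of the resulting schedule (standard MultiStepLR semantics, not code from this commit):

import bisect

lr_G, lr_steps, lr_gamma = 1e-4, [50000, 100000, 200000, 300000], 0.5

def lr_at(iteration):
    # rate after `iteration` steps: base lr decayed once per milestone already passed
    return lr_G * lr_gamma ** bisect.bisect_right(lr_steps, iteration)

for it in (0, 50000, 250000, 499999):
    print(it, lr_at(it))  # 1e-04, 5e-05, 1.25e-05, 6.25e-06

By "niter": 500000.0 the generator is training at 1/16 of its initial rate.
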
esrgan_plus/codes/options/train/train_SRGAN.json
ADDED
@@ -0,0 +1,87 @@
+// Not totally the same as SRGAN in <Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network>
+{
+    "name": "debug_002_SRGAN_x4_DIV2K" // please remove "debug_" during training
+    , "use_tb_logger": true
+    , "model":"srgan"
+    , "scale": 4
+    , "gpu_ids": [0]
+
+    , "datasets": {
+        "train": {
+            "name": "DIV2K"
+            , "mode": "LRHR"
+            , "dataroot_HR": "/home/carraz/datasets/DIV2K800/DIV2K800_sub.lmdb"
+            , "dataroot_LR": "/home/carraz/datasets/DIV2K800/DIV2K800_sub_bicLRx4.lmdb"
+            , "subset_file": null
+            , "use_shuffle": true
+            , "n_workers": 8
+            , "batch_size": 16
+            , "HR_size": 128
+            , "use_flip": true
+            , "use_rot": true
+        }
+        , "val": {
+            "name": "val_set14_part"
+            , "mode": "LRHR"
+            , "dataroot_HR": "/home/carraz/datasets/val_set14_part/Set14"
+            , "dataroot_LR": "/home/carraz/datasets/val_set14_part/Set14_bicLRx4"
+        }
+    }
+
+    , "path": {
+        "root": "/home/carraz/nESRGANplus"
+        // , "resume_state": "../experiments/debug_002_SRGAN_x4_DIV2K/training_state/16.state"
+        , "pretrain_model_G": "../experiments/pretrained_models/SRResNet_bicx4_in3nf64nb16.pth"
+    }
+
+    , "network_G": {
+        "which_model_G": "sr_resnet" // RRDB_net | sr_resnet
+        , "norm_type": null
+        , "mode": "CNA"
+        , "nf": 64
+        , "nb": 16
+        , "in_nc": 3
+        , "out_nc": 3
+    }
+    , "network_D": {
+        "which_model_D": "discriminator_vgg_128"
+        , "norm_type": "batch"
+        , "act_type": "leakyrelu"
+        , "mode": "CNA"
+        , "nf": 64
+        , "in_nc": 3
+    }
+
+    , "train": {
+        "lr_G": 1e-4
+        , "weight_decay_G": 0
+        , "beta1_G": 0.9
+        , "lr_D": 1e-4
+        , "weight_decay_D": 0
+        , "beta1_D": 0.9
+        , "lr_scheme": "MultiStepLR"
+        , "lr_steps": [50000, 100000, 200000, 300000]
+        , "lr_gamma": 0.5
+
+        , "pixel_criterion": "l1"
+        , "pixel_weight": 1e-2
+        , "feature_criterion": "l1"
+        , "feature_weight": 1
+        , "gan_type": "vanilla"
+        , "gan_weight": 5e-3
+
+        //for wgan-gp
+        // , "D_update_ratio": 1
+        // , "D_init_iters": 0
+        // , "gp_weigth": 10
+
+        , "manual_seed": 0
+        , "niter": 5e5
+        , "val_freq": 5e3
+    }
+
+    , "logger": {
+        "print_freq": 200
+        , "save_checkpoint_freq": 5e3
+    }
+}
esrgan_plus/codes/options/train/train_SRResNet.json
ADDED
@@ -0,0 +1,66 @@
+// Not totally the same as SRResNet in <Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network>
+// With 16 Residual blocks w/o BN
+{
+    "name": "debug_001_SRResNet_PSNR_x4_DIV2K" // please remove "debug_" during training
+    , "use_tb_logger": true
+    , "model":"sr"
+    , "scale": 4
+    , "gpu_ids": [0]
+
+    , "datasets": {
+        "train": {
+            "name": "DIV2K"
+            , "mode": "LRHR"
+            , "dataroot_HR": "/home/carraz/datasets/DIV2K800/DIV2K800_sub.lmdb"
+            , "dataroot_LR": "/home/carraz/datasets/DIV2K800/DIV2K800_sub_bicLRx4.lmdb"
+            , "subset_file": null
+            , "use_shuffle": true
+            , "n_workers": 8
+            , "batch_size": 16
+            , "HR_size": 128 // 128 | 192
+            , "use_flip": true
+            , "use_rot": true
+        }
+        , "val": {
+            "name": "val_set5"
+            , "mode": "LRHR"
+            , "dataroot_HR": "/home/carraz/datasets/val_set5/Set5"
+            , "dataroot_LR": "/home/carraz/datasets/val_set5/Set5_bicLRx4"
+        }
+    }
+
+    , "path": {
+        "root": "/home/carraz/nESRGANplus"
+        // , "resume_state": "../experiments/debug_001_RRDB_PSNR_x4_DIV2K/training_state/200.state"
+        , "pretrain_model_G": null
+    }
+
+    , "network_G": {
+        "which_model_G": "sr_resnet" // RRDB_net | sr_resnet
+        , "norm_type": null
+        , "mode": "CNA"
+        , "nf": 64
+        , "nb": 16
+        , "in_nc": 3
+        , "out_nc": 3
+    }
+
+    , "train": {
+        "lr_G": 2e-4
+        , "lr_scheme": "MultiStepLR"
+        , "lr_steps": [200000, 400000, 600000, 800000]
+        , "lr_gamma": 0.5
+
+        , "pixel_criterion": "l1"
+        , "pixel_weight": 1.0
+        , "val_freq": 5e3
+
+        , "manual_seed": 0
+        , "niter": 1e6
+    }
+
+    , "logger": {
+        "print_freq": 200
+        , "save_checkpoint_freq": 5e3
+    }
+}
esrgan_plus/codes/options/train/train_sftgan.json
ADDED
@@ -0,0 +1,76 @@
+{
+    "name": "debug_003_SFTGANx4_OST" // please remove "debug_" during training
+    , "use_tb_logger": false
+    , "model": "sftgan"
+    , "scale": 4
+    , "gpu_ids": [0]
+
+    , "datasets": {
+        "train": {
+            "name": "OST"
+            , "mode": "LRHRseg_bg"
+            , "dataroot_HR": "/home/carraz/datasets/OST/train/img"
+            , "dataroot_HR_bg": "/home/carraz/datasets/DIV2K800/DIV2K800_sub"
+            , "dataroot_LR": null
+            , "subset_file": null
+            , "use_shuffle": true
+            , "n_workers": 8
+            , "batch_size": 16
+            , "HR_size": 96
+            , "use_flip": true
+            , "use_rot": false
+        }
+        , "val": {
+            "name": "val_OST300_part"
+            , "mode": "LRHRseg_bg"
+            , "dataroot_HR": "/home/carraz/datasets/OST/val/img"
+            , "dataroot_LR": null
+        }
+    }
+
+    , "path": {
+        "root": "/home/carraz/nESRGANplus"
+        , "resume_state": null
+        , "pretrain_model_G": "../experiments/pretrained_models/sft_net_ini.pth"
+    }
+
+    , "network_G": {
+        "which_model_G": "sft_arch"
+    }
+    , "network_D": {
+        "which_model_D": "dis_acd"
+    }
+
+    , "train": {
+        "lr_G": 1e-4
+        , "weight_decay_G": 0
+        , "beta1_G": 0.9
+        , "lr_D": 1e-4
+        , "weight_decay_D": 0
+        , "beta1_D": 0.9
+        , "lr_scheme": "MultiStepLR"
+        , "lr_steps": [50000, 100000, 150000, 200000]
+        , "lr_gamma": 0.5
+
+        , "pixel_criterion": "l1"
+        , "pixel_weight": 0
+        , "feature_criterion": "l1"
+        , "feature_weight": 1
+        , "gan_type": "vanilla"
+        , "gan_weight": 5e-3
+
+        //for wgan-gp
+        // , "D_update_ratio": 1
+        // , "D_init_iters": 0
+        // , "gp_weigth": 10
+
+        , "manual_seed": 0
+        , "niter": 6e5
+        , "val_freq": 2e3
+    }
+
+    , "logger": {
+        "print_freq": 200
+        , "save_checkpoint_freq": 2e3
+    }
+}
esrgan_plus/codes/options/train/train_sr.json
ADDED
@@ -0,0 +1,66 @@
+{
+    "name": "debug_001_RRDB_PSNR_x4_DIV2K" // please remove "debug_" during training
+    , "use_tb_logger": true
+    , "model":"sr"
+    , "scale": 4
+    , "gpu_ids": [0]
+
+    , "datasets": {
+        "train": {
+            "name": "DIV2K"
+            , "mode": "LRHR"
+            , "dataroot_HR": "/home/carraz/datasets/DIV2K800/DIV2K800_sub.lmdb"
+            , "dataroot_LR": "/home/carraz/datasets/DIV2K800/DIV2K800_sub_bicLRx4.lmdb"
+            , "subset_file": null
+            , "use_shuffle": true
+            , "n_workers": 8
+            , "batch_size": 16
+            , "HR_size": 128 // 128 | 192
+            , "use_flip": true
+            , "use_rot": true
+        }
+        , "val": {
+            "name": "val_set5"
+            , "mode": "LRHR"
+            , "dataroot_HR": "/home/carraz/datasets/val_set5/Set5"
+            , "dataroot_LR": "/home/carraz/datasets/val_set5/Set5_bicLRx4"
+        }
+    }
+
+    , "path": {
+        "root": "/home/carraz/nESRGANplus"
+        // , "resume_state": "../experiments/debug_001_RRDB_PSNR_x4_DIV2K/training_state/200.state"
+        , "pretrain_model_G": null
+    }
+
+    , "network_G": {
+        "which_model_G": "RRDB_net" // RRDB_net | sr_resnet
+        , "norm_type": null
+        , "mode": "CNA"
+        , "nf": 64
+        , "nb": 23
+        , "in_nc": 3
+        , "out_nc": 3
+        , "gc": 32
+        , "group": 1
+    }
+
+    , "train": {
+        "lr_G": 2e-4
+        , "lr_scheme": "MultiStepLR"
+        , "lr_steps": [200000, 400000, 600000, 800000]
+        , "lr_gamma": 0.5
+
+        , "pixel_criterion": "l1"
+        , "pixel_weight": 1.0
+        , "val_freq": 5e3
+
+        , "manual_seed": 0
+        , "niter": 1e6
+    }
+
+    , "logger": {
+        "print_freq": 200
+        , "save_checkpoint_freq": 5e3
+    }
+}
esrgan_plus/codes/scripts/README.md
ADDED
@@ -0,0 +1,8 @@
+# Scripts
+We provide some useful scripts here.
+
+## List
+
+| Name | Description |
+|:---:|:---:|
+| back projection | `Matlab` codes for back projection |
esrgan_plus/codes/scripts/back_projection/backprojection.m
ADDED
@@ -0,0 +1,20 @@
+function [im_h] = backprojection(im_h, im_l, maxIter)
+
+[row_l, col_l,~] = size(im_l);
+[row_h, col_h,~] = size(im_h);
+
+p = fspecial('gaussian', 5, 1);
+p = p.^2;
+p = p./sum(p(:));
+
+im_l = double(im_l);
+im_h = double(im_h);
+
+for ii = 1:maxIter
+    im_l_s = imresize(im_h, [row_l, col_l], 'bicubic');
+    im_diff = im_l - im_l_s;
+    im_diff = imresize(im_diff, [row_h, col_h], 'bicubic');
+    im_h(:,:,1) = im_h(:,:,1) + conv2(im_diff(:,:,1), p, 'same');
+    im_h(:,:,2) = im_h(:,:,2) + conv2(im_diff(:,:,2), p, 'same');
+    im_h(:,:,3) = im_h(:,:,3) + conv2(im_diff(:,:,3), p, 'same');
+end
esrgan_plus/codes/scripts/back_projection/main_bp.m
ADDED
@@ -0,0 +1,22 @@
+clear; close all; clc;
+
+LR_folder = './LR'; % LR
+preout_folder = './results'; % pre output
+save_folder = './results_20bp';
+filepaths = dir(fullfile(preout_folder, '*.png'));
+max_iter = 20;
+
+if ~ exist(save_folder, 'dir')
+    mkdir(save_folder);
+end
+
+for idx_im = 1:length(filepaths)
+    fprintf([num2str(idx_im) '\n']);
+    im_name = filepaths(idx_im).name;
+    im_LR = im2double(imread(fullfile(LR_folder, im_name)));
+    im_out = im2double(imread(fullfile(preout_folder, im_name)));
+    %tic
+    im_out = backprojection(im_out, im_LR, max_iter);
+    %toc
+    imwrite(im_out, fullfile(save_folder, im_name));
+end
esrgan_plus/codes/scripts/back_projection/main_reverse_filter.m
ADDED
@@ -0,0 +1,25 @@
+clear; close all; clc;
+
+LR_folder = './LR'; % LR
+preout_folder = './results'; % pre output
+save_folder = './results_20if';
+filepaths = dir(fullfile(preout_folder, '*.png'));
+max_iter = 20;
+
+if ~ exist(save_folder, 'dir')
+    mkdir(save_folder);
+end
+
+for idx_im = 1:length(filepaths)
+    fprintf([num2str(idx_im) '\n']);
+    im_name = filepaths(idx_im).name;
+    im_LR = im2double(imread(fullfile(LR_folder, im_name)));
+    im_out = im2double(imread(fullfile(preout_folder, im_name)));
+    J = imresize(im_LR,4,'bicubic');
+    %tic
+    for m = 1:max_iter
+        im_out = im_out + (J - imresize(imresize(im_out,1/4,'bicubic'),4,'bicubic'));
+    end
+    %toc
+    imwrite(im_out, fullfile(save_folder, im_name));
+end
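For readers without MATLAB, the same reverse-filter iteration in Python; cv2.resize stands in for imresize, and the two bicubic kernels differ slightly, so outputs will not match the MATLAB script bit-for-bit (a sketch, assuming float images in [0, 1]):

import cv2
import numpy as np

def reverse_filter(im_out, im_lr, scale=4, max_iter=20):
    # mirrors main_reverse_filter.m: repeatedly add back the residual lost by down/up-sampling
    h, w = im_lr.shape[:2]
    up = lambda im: cv2.resize(im, (w * scale, h * scale), interpolation=cv2.INTER_CUBIC)
    down = lambda im: cv2.resize(im, (w, h), interpolation=cv2.INTER_CUBIC)
    J = up(im_lr)
    for _ in range(max_iter):
        im_out = im_out + (J - up(down(im_out)))
    return np.clip(im_out, 0.0, 1.0)
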
esrgan_plus/codes/scripts/color2gray.py
ADDED
@@ -0,0 +1,63 @@
+import os
+import os.path
+import sys
+from multiprocessing import Pool
+import cv2
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from data.util import bgr2ycbcr
+from utils.progress_bar import ProgressBar
+
+
+def main():
+    """A multi-thread tool for converting RGB images to gray/Y images."""
+
+    input_folder = '/home/carraz/datasets/DIV2K800/DIV2K800'
+    save_folder = '/home/carraz/datasets/DIV2K800/DIV2K800_gray'
+    mode = 'gray'  # 'gray' | 'y': Y channel in YCbCr space
+    compression_level = 3  # 3 is the default value in cv2
+    # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
+    # compression time. If read raw images during training, use 0 for faster IO speed.
+    n_thread = 20  # thread number
+
+    if not os.path.exists(save_folder):
+        os.makedirs(save_folder)
+        print('mkdir [{:s}] ...'.format(save_folder))
+    else:
+        print('Folder [{:s}] already exists. Exit...'.format(save_folder))
+        sys.exit(1)
+    # print('Parent process {:d}.'.format(os.getpid()))
+
+    img_list = []
+    for root, _, file_list in sorted(os.walk(input_folder)):
+        path = [os.path.join(root, x) for x in file_list]  # assume only images in the input_folder
+        img_list.extend(path)
+
+    def update(arg):
+        pbar.update(arg)
+
+    pbar = ProgressBar(len(img_list))
+
+    pool = Pool(n_thread)
+    for path in img_list:
+        pool.apply_async(worker, args=(path, save_folder, mode, compression_level), callback=update)
+    pool.close()
+    pool.join()
+    print('All subprocesses done.')
+
+
+def worker(path, save_folder, mode, compression_level):
+    img_name = os.path.basename(path)
+    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # BGR
+    if mode == 'gray':
+        img_y = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+    else:
+        img_y = bgr2ycbcr(img, only_y=True)
+    cv2.imwrite(
+        os.path.join(save_folder, img_name), img_y,
+        [cv2.IMWRITE_PNG_COMPRESSION, compression_level])
+    return 'Processing {:s} ...'.format(img_name)
+
+
+if __name__ == '__main__':
+    main()
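The 'y' branch above delegates to data.util.bgr2ycbcr, which lives outside this hunk; for orientation, the studio-range BT.601 luma it is expected to compute looks like this (a sketch, assuming uint8 BGR input as returned by cv2.imread):

import numpy as np

def luma_bt601(img_bgr):
    # Y of YCbCr, 16-235 range: the channel SR papers evaluate PSNR/SSIM on
    b, g, r = [img_bgr[..., i].astype(np.float64) for i in range(3)]
    y = 16.0 + (65.481 * r + 128.553 * g + 24.966 * b) / 255.0
    return np.round(y).astype(np.uint8)
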
esrgan_plus/codes/scripts/create_lmdb.py
ADDED
@@ -0,0 +1,66 @@
+import sys
+import os.path
+import glob
+import pickle
+import lmdb
+import cv2
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from utils.progress_bar import ProgressBar
+
+# configurations
+img_folder = '/home/carraz/datasets/DIV2K800/DIV2K800/*'  # glob matching pattern
+lmdb_save_path = '/home/carraz/datasets/DIV2K800/DIV2K800.lmdb'  # must end with .lmdb
+mode = 1  # 1 for small data (more memory), 2 for large data (less memory)
+
+img_list = sorted(glob.glob(img_folder))
+
+print('Read images...')
+# mode 1 small data, read all imgs
+if mode == 1:
+    dataset = [cv2.imread(v, cv2.IMREAD_UNCHANGED) for v in img_list]
+    data_size = sum([img.nbytes for img in dataset])
+# mode 2 large data, read imgs later
+elif mode == 2:
+    data_size = sum(os.stat(v).st_size for v in img_list)
+else:
+    raise ValueError('mode should be 1 or 2')
+
+env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
+print('Finish reading {} images.\nWrite lmdb...'.format(len(img_list)))
+
+pbar = ProgressBar(len(img_list))
+batch = 3000  # can be modified according to memory usage
+txn = env.begin(write=True)  # txn is a Transaction object
+for i, v in enumerate(img_list):
+    pbar.update('Write {}'.format(v))
+    base_name = os.path.splitext(os.path.basename(v))[0]
+    key = base_name.encode('ascii')
+    data = dataset[i] if mode == 1 else cv2.imread(v, cv2.IMREAD_UNCHANGED)
+    if data.ndim == 2:
+        H, W = data.shape
+        C = 1
+    else:
+        H, W, C = data.shape
+    meta_key = (base_name + '.meta').encode('ascii')
+    meta = '{:d}, {:d}, {:d}'.format(H, W, C)
+    # The encode is only essential in Python 3
+    txn.put(key, data)
+    txn.put(meta_key, meta.encode('ascii'))
+    if mode == 2 and i % batch == batch - 1:
+        txn.commit()
+        txn = env.begin(write=True)
+
+txn.commit()
+env.close()
+
+print('Finish writing lmdb.')
+
+# create keys cache
+keys_cache_file = os.path.join(lmdb_save_path, '_keys_cache.p')
+env = lmdb.open(lmdb_save_path, readonly=True, lock=False, readahead=False, meminit=False)
+with env.begin(write=False) as txn:
+    print('Create lmdb keys cache: {}'.format(keys_cache_file))
+    keys = [key.decode('ascii') for key, _ in txn.cursor()]
+    pickle.dump(keys, open(keys_cache_file, "wb"))
+print('Finish creating lmdb keys cache.')
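A sketch of reading one image back out of the LMDB written above, relying only on the `name` / `name.meta` key layout and the keys cache this script just produced (dtype uint8 is an assumption that holds for cv2.imread output):

import os.path
import pickle
import lmdb
import numpy as np

lmdb_save_path = '/home/carraz/datasets/DIV2K800/DIV2K800.lmdb'
keys = pickle.load(open(os.path.join(lmdb_save_path, '_keys_cache.p'), 'rb'))
key = next(k for k in keys if not k.endswith('.meta'))  # first image key

env = lmdb.open(lmdb_save_path, readonly=True, lock=False, readahead=False, meminit=False)
with env.begin(write=False) as txn:
    buf = txn.get(key.encode('ascii'))
    H, W, C = [int(s) for s in txn.get((key + '.meta').encode('ascii')).decode('ascii').split(',')]
img = np.frombuffer(buf, dtype=np.uint8)
img = img.reshape(H, W) if C == 1 else img.reshape(H, W, C)
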
esrgan_plus/codes/scripts/extract_enlarge_patches.py
ADDED
@@ -0,0 +1,64 @@
+import os.path
+import glob
+import cv2
+
+crt_path = os.path.dirname(os.path.realpath(__file__))
+
+# configurations
+h_start, h_len = 170, 64
+w_start, w_len = 232, 100
+enlarge_ratio = 3
+line_width = 2
+color = 'yellow'
+
+folder = os.path.join(crt_path, './ori/*')
+save_patch_folder = os.path.join(crt_path, './patch')
+save_rect_folder = os.path.join(crt_path, './rect')
+
+color_tb = {}
+color_tb['yellow'] = (0, 255, 255)
+color_tb['green'] = (0, 255, 0)
+color_tb['red'] = (0, 0, 255)
+color_tb['magenta'] = (255, 0, 255)
+color_tb['matlab_blue'] = (189, 114, 0)
+color_tb['matlab_orange'] = (25, 83, 217)
+color_tb['matlab_yellow'] = (32, 177, 237)
+color_tb['matlab_purple'] = (142, 47, 126)
+color_tb['matlab_green'] = (48, 172, 119)
+color_tb['matlab_liblue'] = (238, 190, 77)
+color_tb['matlab_brown'] = (47, 20, 162)
+color = color_tb[color]
+img_list = glob.glob(folder)
+images = []
+
+# make temp folder
+if not os.path.exists(save_patch_folder):
+    os.makedirs(save_patch_folder)
+    print('mkdir [{}] ...'.format(save_patch_folder))
+if not os.path.exists(save_rect_folder):
+    os.makedirs(save_rect_folder)
+    print('mkdir [{}] ...'.format(save_rect_folder))
+
+for i, path in enumerate(img_list):
+    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+    base_name = os.path.splitext(os.path.basename(path))[0]
+    print(i, base_name)
+    # crop patch
+    if img.ndim == 2:
+        patch = img[h_start:h_start + h_len, w_start:w_start + w_len]
+    elif img.ndim == 3:
+        patch = img[h_start:h_start + h_len, w_start:w_start + w_len, :]
+    else:
+        raise ValueError('Wrong image dim [{:d}]'.format(img.ndim))
+
+    # enlarge patch if necessary
+    if enlarge_ratio > 1:
+        H, W, _ = patch.shape
+        patch = cv2.resize(patch, (W * enlarge_ratio, H * enlarge_ratio), \
+            interpolation=cv2.INTER_CUBIC)
+    cv2.imwrite(os.path.join(save_patch_folder, base_name + '_patch.png'), patch)
+
+    # draw rectangle
+    img_rect = cv2.rectangle(img, (w_start, h_start), (w_start + w_len, h_start + h_len),
+        color, line_width)
+    cv2.imwrite(os.path.join(save_rect_folder, base_name + '_rect.png'), img_rect)
esrgan_plus/codes/scripts/extract_subimgs_single.py
ADDED
@@ -0,0 +1,88 @@
+import os
+import os.path
+import sys
+from multiprocessing import Pool
+import numpy as np
+import cv2
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from utils.progress_bar import ProgressBar
+
+
+def main():
+    """A multi-thread tool to crop sub images."""
+    input_folder = '/home/carraz/datasets/DIV2K800/DIV2K800'
+    save_folder = '/home/carraz/datasets/DIV2K800/DIV2K800_sub'
+    n_thread = 20
+    crop_sz = 480
+    step = 240
+    thres_sz = 48
+    compression_level = 3  # 3 is the default value in cv2
+    # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
+    # compression time. If read raw images during training, use 0 for faster IO speed.
+
+    if not os.path.exists(save_folder):
+        os.makedirs(save_folder)
+        print('mkdir [{:s}] ...'.format(save_folder))
+    else:
+        print('Folder [{:s}] already exists. Exit...'.format(save_folder))
+        sys.exit(1)
+
+    img_list = []
+    for root, _, file_list in sorted(os.walk(input_folder)):
+        path = [os.path.join(root, x) for x in file_list]  # assume only images in the input_folder
+        img_list.extend(path)
+
+    def update(arg):
+        pbar.update(arg)
+
+    pbar = ProgressBar(len(img_list))
+
+    pool = Pool(n_thread)
+    for path in img_list:
+        pool.apply_async(worker,
+            args=(path, save_folder, crop_sz, step, thres_sz, compression_level),
+            callback=update)
+    pool.close()
+    pool.join()
+    print('All subprocesses done.')
+
+
+def worker(path, save_folder, crop_sz, step, thres_sz, compression_level):
+    img_name = os.path.basename(path)
+    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+
+    n_channels = len(img.shape)
+    if n_channels == 2:
+        h, w = img.shape
+    elif n_channels == 3:
+        h, w, c = img.shape
+    else:
+        raise ValueError('Wrong image shape - {}'.format(n_channels))
+
+    h_space = np.arange(0, h - crop_sz + 1, step)
+    if h - (h_space[-1] + crop_sz) > thres_sz:
+        h_space = np.append(h_space, h - crop_sz)
+    w_space = np.arange(0, w - crop_sz + 1, step)
+    if w - (w_space[-1] + crop_sz) > thres_sz:
+        w_space = np.append(w_space, w - crop_sz)
+
+    index = 0
+    for x in h_space:
+        for y in w_space:
+            index += 1
+            if n_channels == 2:
+                crop_img = img[x:x + crop_sz, y:y + crop_sz]
+            else:
+                crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
+            crop_img = np.ascontiguousarray(crop_img)
+            # var = np.var(crop_img / 255)
+            # if var > 0.008:
+            #     print(img_name, index_str, var)
+            cv2.imwrite(
+                os.path.join(save_folder, img_name.replace('.png', '_s{:03d}.png'.format(index))),
+                crop_img, [cv2.IMWRITE_PNG_COMPRESSION, compression_level])
+    return 'Processing {:s} ...'.format(img_name)
+
+
+if __name__ == '__main__':
+    main()
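To make the grid logic above concrete, a worked example with the script's crop_sz=480, step=240, thres_sz=48 on a hypothetical 1404-pixel-high DIV2K image:

import numpy as np

h, crop_sz, step, thres_sz = 1404, 480, 240, 48
h_space = np.arange(0, h - crop_sz + 1, step)  # [0 240 480 720]
if h - (h_space[-1] + crop_sz) > thres_sz:     # 204 px would be wasted at the border: > 48
    h_space = np.append(h_space, h - crop_sz)  # add a final crop flush with the edge
print(h_space)                                 # [  0 240 480 720 924]

So a remainder wider than thres_sz earns one extra, partially overlapping crop; a thinner one is simply discarded.
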
esrgan_plus/codes/scripts/generate_mod_LR_bic.m
ADDED
@@ -0,0 +1,82 @@
+function generate_mod_LR_bic()
+%% MATLAB code to generate mod images, bicubic-downsampled LR images and bicubic-upsampled images.
+
+%% set parameters
+% comment out the unnecessary lines
+input_folder = '/home/carraz/datasets/DIV2K800/DIV2K800_sub';
+% save_mod_folder = '';
+save_LR_folder = '/home/carraz/datasets/DIV2K800/DIV2K800_sub_bicLRx4';
+% save_bic_folder = '';
+
+up_scale = 4;
+mod_scale = 4;
+
+if exist('save_mod_folder', 'var')
+    if exist(save_mod_folder, 'dir')
+        disp(['It will overwrite ', save_mod_folder]);
+    else
+        mkdir(save_mod_folder);
+    end
+end
+if exist('save_LR_folder', 'var')
+    if exist(save_LR_folder, 'dir')
+        disp(['It will overwrite ', save_LR_folder]);
+    else
+        mkdir(save_LR_folder);
+    end
+end
+if exist('save_bic_folder', 'var')
+    if exist(save_bic_folder, 'dir')
+        disp(['It will overwrite ', save_bic_folder]);
+    else
+        mkdir(save_bic_folder);
+    end
+end
+
+idx = 0;
+filepaths = dir(fullfile(input_folder, '*.*'));
+for i = 1 : length(filepaths)
+    [paths, imname, ext] = fileparts(filepaths(i).name);
+    if isempty(imname)
+        disp('Ignore . folder.');
+    elseif strcmp(imname, '.')
+        disp('Ignore .. folder.');
+    else
+        idx = idx + 1;
+        str_rlt = sprintf('%d\t%s.\n', idx, imname);
+        fprintf(str_rlt);
+        % read image
+        img = imread(fullfile(input_folder, [imname, ext]));
+        img = im2double(img);
+        % modcrop
+        img = modcrop(img, mod_scale);
+        if exist('save_mod_folder', 'var')
+            imwrite(img, fullfile(save_mod_folder, [imname, '.png']));
+        end
+        % LR
+        im_LR = imresize(img, 1/up_scale, 'bicubic');
+        if exist('save_LR_folder', 'var')
+            imwrite(im_LR, fullfile(save_LR_folder, [imname, '_bicLRx4.png']));
+        end
+        % Bicubic
+        if exist('save_bic_folder', 'var')
+            im_B = imresize(im_LR, up_scale, 'bicubic');
+            imwrite(im_B, fullfile(save_bic_folder, [imname, '_bicx4.png']));
+        end
+    end
+end
+end
+
+%% modcrop
+function img = modcrop(img, modulo)
+if size(img, 3) == 1
+    sz = size(img);
+    sz = sz - mod(sz, modulo);
+    img = img(1:sz(1), 1:sz(2));
+else
+    tmpsz = size(img);
+    sz = tmpsz(1:2);
+    sz = sz - mod(sz, modulo);
+    img = img(1:sz(1), 1:sz(2), :);
+end
+end
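`modcrop` trims each spatial dimension down to the nearest multiple of `mod_scale`, which guarantees the ×4 LR image maps back onto the HR image with no fractional-pixel offset. A one-function Python sketch of the same operation (an illustration, not part of this repo):

import numpy as np

def modcrop(img, modulo):
    # Trim H and W to the nearest lower multiple of `modulo`; works for HxW and HxWxC arrays.
    h = img.shape[0] - img.shape[0] % modulo
    w = img.shape[1] - img.shape[1] % modulo
    return img[:h, :w, ...]

print(modcrop(np.zeros((103, 98, 3)), 4).shape)  # (100, 96, 3)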
esrgan_plus/codes/scripts/generate_mod_LR_bic.py
ADDED
@@ -0,0 +1,74 @@
+import os, sys
+import cv2
+import numpy as np
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from data.util import imresize_np
+
+def generate_mod_LR_bic():
+    # set parameters
+    up_scale = 4
+    mod_scale = 4
+    # set data dir
+    sourcedir = '/data/datasets/img'
+    savedir = '/data/datasets/mod'
+
+    saveHRpath = os.path.join(savedir, 'HR', 'x' + str(mod_scale))
+    saveLRpath = os.path.join(savedir, 'LR', 'x' + str(up_scale))
+    saveBicpath = os.path.join(savedir, 'Bic', 'x' + str(up_scale))
+
+    if not os.path.isdir(sourcedir):
+        print('Error: No source data found')
+        sys.exit(1)
+    if not os.path.isdir(savedir):
+        os.mkdir(savedir)
+
+    if not os.path.isdir(os.path.join(savedir, 'HR')):
+        os.mkdir(os.path.join(savedir, 'HR'))
+    if not os.path.isdir(os.path.join(savedir, 'LR')):
+        os.mkdir(os.path.join(savedir, 'LR'))
+    if not os.path.isdir(os.path.join(savedir, 'Bic')):
+        os.mkdir(os.path.join(savedir, 'Bic'))
+
+    if not os.path.isdir(saveHRpath):
+        os.mkdir(saveHRpath)
+    else:
+        print('It will overwrite ' + str(saveHRpath))
+
+    if not os.path.isdir(saveLRpath):
+        os.mkdir(saveLRpath)
+    else:
+        print('It will overwrite ' + str(saveLRpath))
+
+    if not os.path.isdir(saveBicpath):
+        os.mkdir(saveBicpath)
+    else:
+        print('It will overwrite ' + str(saveBicpath))
+
+    filepaths = [f for f in os.listdir(sourcedir) if f.endswith('.png')]
+    num_files = len(filepaths)
+
+    # prepare data with augmentation
+    for i in range(num_files):
+        filename = filepaths[i]
+        print('No.{} -- Processing {}'.format(i, filename))
+        # read image
+        image = cv2.imread(os.path.join(sourcedir, filename))
+
+        width = int(np.floor(image.shape[1] / mod_scale))
+        height = int(np.floor(image.shape[0] / mod_scale))
+        # modcrop
+        if len(image.shape) == 3:
+            image_HR = image[0:mod_scale * height, 0:mod_scale * width, :]
+        else:
+            image_HR = image[0:mod_scale * height, 0:mod_scale * width]
+        # LR
+        image_LR = imresize_np(image_HR, 1 / up_scale, True)
+        # bic
+        image_Bic = imresize_np(image_LR, up_scale, True)
+
+        cv2.imwrite(os.path.join(saveHRpath, filename), image_HR)
+        cv2.imwrite(os.path.join(saveLRpath, filename), image_LR)
+        cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic)
+
+if __name__ == "__main__":
+    generate_mod_LR_bic()
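After a run, every HR/LR/Bic triple written by this script should stay exactly ×4-aligned. A hedged sanity check (the `0001.png` name is hypothetical; the paths are just the defaults above):

import cv2

hr = cv2.imread('/data/datasets/mod/HR/x4/0001.png')
lr = cv2.imread('/data/datasets/mod/LR/x4/0001.png')
bic = cv2.imread('/data/datasets/mod/Bic/x4/0001.png')
assert hr.shape[0] == 4 * lr.shape[0] and hr.shape[1] == 4 * lr.shape[1]
assert bic.shape == hr.shape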
esrgan_plus/codes/scripts/make_gif_video.py
ADDED
@@ -0,0 +1,106 @@
+"""
+Add text to images, then make a gif/video sequence from the images.
+
+Since the gif created here has low quality and color issues, use this script to generate the
+text-annotated images and then use `gifski`.
+
+Call `ffmpeg` to make the video.
+"""
+
+import os.path
+import numpy as np
+import cv2
+
+
+crt_path = os.path.dirname(os.path.realpath(__file__))
+
+# configurations
+img_name_list = ['x1', 'x2', 'x3', 'x4', 'x5']
+ext = '.png'
+text_list = ['1', '2', '3', '4', '5']
+h_start, h_len = 0, 576
+w_start, w_len = 10, 352
+enlarge_ratio = 1
+txt_pos = (10, 50)  # w, h
+font_size = 1.5
+font_thickness = 4
+color = 'red'
+duration = 0.8  # second
+use_imageio = False  # use imageio to make gif
+make_video = False  # make video using ffmpeg
+
+is_crop = True
+if h_start == 0 or w_start == 0:
+    is_crop = False  # do not crop
+
+img_name_list = [x + ext for x in img_name_list]
+input_folder = os.path.join(crt_path, './ori')
+save_folder = os.path.join(crt_path, './ori')
+color_tb = {}
+color_tb['yellow'] = (0, 255, 255)
+color_tb['green'] = (0, 255, 0)
+color_tb['red'] = (0, 0, 255)
+color_tb['magenta'] = (255, 0, 255)
+color_tb['matlab_blue'] = (189, 114, 0)
+color_tb['matlab_orange'] = (25, 83, 217)
+color_tb['matlab_yellow'] = (32, 177, 237)
+color_tb['matlab_purple'] = (142, 47, 126)
+color_tb['matlab_green'] = (48, 172, 119)
+color_tb['matlab_liblue'] = (238, 190, 77)
+color_tb['matlab_brown'] = (47, 20, 162)
+color = color_tb[color]
+
+img_list = []
+
+# make temp dir
+if not os.path.exists(save_folder):
+    os.makedirs(save_folder)
+    print('mkdir [{}] ...'.format(save_folder))
+if make_video:
+    # tmp folder to save images for the video
+    tmp_video_folder = os.path.join(crt_path, '_tmp_video')
+    if not os.path.exists(tmp_video_folder):
+        os.makedirs(tmp_video_folder)
+
+idx = 0
+for img_name, write_txt in zip(img_name_list, text_list):
+    img = cv2.imread(os.path.join(input_folder, img_name), cv2.IMREAD_UNCHANGED)
+    base_name = os.path.splitext(img_name)[0]
+    print(base_name)
+    # crop image
+    if is_crop:
+        print('Crop image ...')
+        if img.ndim == 2:
+            img = img[h_start:h_start + h_len, w_start:w_start + w_len]
+        elif img.ndim == 3:
+            img = img[h_start:h_start + h_len, w_start:w_start + w_len, :]
+        else:
+            raise ValueError('Wrong image dim [{:d}]'.format(img.ndim))
+
+    # enlarge img if necessary
+    if enlarge_ratio > 1:
+        H, W, _ = img.shape
+        img = cv2.resize(img, (W * enlarge_ratio, H * enlarge_ratio),
+                         interpolation=cv2.INTER_CUBIC)
+
+    # add text
+    font = cv2.FONT_HERSHEY_COMPLEX
+    cv2.putText(img, write_txt, txt_pos, font, font_size, color, font_thickness, cv2.LINE_AA)
+    cv2.imwrite(os.path.join(save_folder, base_name + '_text.png'), img)
+    if make_video:
+        idx += 1
+        cv2.imwrite(os.path.join(tmp_video_folder, '{:05d}.png'.format(idx)), img)
+
+    img = np.ascontiguousarray(img[:, :, [2, 1, 0]])  # BGR -> RGB for imageio
+    img_list.append(img)
+
+if use_imageio:
+    import imageio
+    imageio.mimsave(os.path.join(save_folder, 'out.gif'), img_list, format='GIF', duration=duration)
+
+if make_video:
+    os.system('ffmpeg -r {:f} -i {:s}/%05d.png -vcodec mpeg4 -y {:s}/movie.mp4'.format(
+        1 / duration, tmp_video_folder, save_folder))
+
+    if os.path.exists(tmp_video_folder):
+        os.system('rm -rf {}'.format(tmp_video_folder))
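For the `gifski` step mentioned in the docstring, an invocation along these lines should work once the `*_text.png` frames exist; `--fps` and `-o` are standard gifski options, but treat the exact command as a sketch for your version:

import os
# Roughly 1/duration frames per second, mirroring the ffmpeg call above.
os.system('gifski --fps 1 -o {0:s}/out.gif {0:s}/*_text.png'.format(save_folder))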
esrgan_plus/codes/scripts/net_interp.py
ADDED
@@ -0,0 +1,20 @@
+import torch
+from collections import OrderedDict
+
+alpha = 0.8
+
+net_PSNR_path = './models/RRDB_PSNR_x4.pth'
+net_ESRGAN_path = './models/RRDB_ESRGAN_x4.pth'
+net_interp_path = './models/interp_{:02d}.pth'.format(int(alpha * 10))
+
+net_PSNR = torch.load(net_PSNR_path)
+net_ESRGAN = torch.load(net_ESRGAN_path)
+net_interp = OrderedDict()
+
+print('Interpolating with alpha = ', alpha)
+
+for k, v_PSNR in net_PSNR.items():
+    v_ESRGAN = net_ESRGAN[k]
+    net_interp[k] = (1 - alpha) * v_PSNR + alpha * v_ESRGAN
+
+torch.save(net_interp, net_interp_path)
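This is the network-interpolation trick from ESRGAN: blending the weights of the PSNR-oriented and GAN-oriented checkpoints, with alpha trading fidelity (alpha = 0) against perceptual quality (alpha = 1). A toy illustration of the per-parameter blend:

import torch

alpha = 0.8
v_PSNR = torch.tensor([1.0, 0.0])
v_ESRGAN = torch.tensor([0.0, 1.0])
print((1 - alpha) * v_PSNR + alpha * v_ESRGAN)  # tensor([0.2000, 0.8000])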
esrgan_plus/codes/scripts/rename.py
ADDED
@@ -0,0 +1,27 @@
+import os
+import os.path
+import glob
+import sys
+
+
+input_folder = '/home/xtwang/Projects/PIRM18/results/pirm_selfval_img06/*'  # glob matching pattern
+save_folder = '/home/xtwang/Projects/PIRM18/results/pirm_selfval_img'
+
+mode = 'cp'  # 'cp' | 'mv'
+file_list = sorted(glob.glob(input_folder))
+
+if not os.path.exists(save_folder):
+    os.makedirs(save_folder)
+    print('mkdir ... ' + save_folder)
+else:
+    print('Folder [{}] already exists. Exit.'.format(save_folder))
+    sys.exit(1)
+
+for i, path in enumerate(file_list):
+    base_name = os.path.splitext(os.path.basename(path))[0]
+
+    new_name = base_name.split('_')[0]
+    new_path = os.path.join(save_folder, new_name + '.png')
+
+    os.system(mode + ' ' + path + ' ' + new_path)
+    print(i, base_name)
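Since `os.system(mode + ' ' + path + ' ' + new_path)` shells out and breaks on paths containing spaces, a `shutil`-based drop-in for the last two lines of the loop body is safer (a sketch using the script's own `mode`, `path` and `new_path`):

import shutil

op = shutil.copy if mode == 'cp' else shutil.move  # 'cp' -> copy, 'mv' -> move
op(path, new_path)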