update repo
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .DS_Store +0 -0
- .gitattributes +2 -0
- LICENSE.txt +97 -0
- dnnlib/__init__.py +9 -0
- dnnlib/__pycache__/__init__.cpython-38.pyc +0 -0
- dnnlib/__pycache__/util.cpython-38.pyc +0 -0
- dnnlib/util.py +473 -0
- encoder4editing/LICENSE +21 -0
- encoder4editing/configs/__init__.py +0 -0
- encoder4editing/configs/data_configs.py +41 -0
- encoder4editing/configs/paths_config.py +28 -0
- encoder4editing/configs/transforms_config.py +62 -0
- encoder4editing/criteria/__init__.py +0 -0
- encoder4editing/criteria/id_loss.py +47 -0
- encoder4editing/criteria/lpips/__init__.py +0 -0
- encoder4editing/criteria/lpips/lpips.py +35 -0
- encoder4editing/criteria/lpips/networks.py +96 -0
- encoder4editing/criteria/lpips/utils.py +30 -0
- encoder4editing/criteria/moco_loss.py +71 -0
- encoder4editing/criteria/w_norm.py +14 -0
- encoder4editing/datasets/__init__.py +0 -0
- encoder4editing/datasets/gt_res_dataset.py +32 -0
- encoder4editing/datasets/images_dataset.py +33 -0
- encoder4editing/datasets/inference_dataset.py +25 -0
- encoder4editing/editings/ganspace.py +22 -0
- encoder4editing/editings/ganspace_pca/cars_pca.pt +3 -0
- encoder4editing/editings/ganspace_pca/ffhq_pca.pt +3 -0
- encoder4editing/editings/interfacegan_directions/age.pt +3 -0
- encoder4editing/editings/interfacegan_directions/pose.pt +3 -0
- encoder4editing/editings/interfacegan_directions/smile.pt +3 -0
- encoder4editing/editings/latent_editor.py +45 -0
- encoder4editing/editings/sefa.py +46 -0
- encoder4editing/environment/e4e_env.yaml +73 -0
- encoder4editing/infer.py +134 -0
- encoder4editing/metrics/LEC.py +134 -0
- encoder4editing/models/__init__.py +0 -0
- encoder4editing/models/discriminator.py +20 -0
- encoder4editing/models/encoders/__init__.py +0 -0
- encoder4editing/models/encoders/helpers.py +140 -0
- encoder4editing/models/encoders/model_irse.py +84 -0
- encoder4editing/models/encoders/psp_encoders.py +235 -0
- encoder4editing/models/latent_codes_pool.py +55 -0
- encoder4editing/models/psp.py +100 -0
- encoder4editing/models/stylegan2/__init__.py +0 -0
- encoder4editing/models/stylegan2/model.py +673 -0
- encoder4editing/models/stylegan2/op/__init__.py +2 -0
- encoder4editing/models/stylegan2/op/fused_act.py +85 -0
- encoder4editing/models/stylegan2/op/fused_bias_act.cpp +21 -0
- encoder4editing/models/stylegan2/op/fused_bias_act_kernel.cu +99 -0
- encoder4editing/models/stylegan2/op/upfirdn2d.cpp +23 -0
.DS_Store
CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
.gitattributes
CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.pth* filter=lfs diff=lfs merge=lfs -text
+ filter=lfs diff=lfs merge=lfs -text

LICENSE.txt
ADDED
@@ -0,0 +1,97 @@
+Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
+
+
+NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA)
+
+
+=======================================================================
+
+1. Definitions
+
+"Licensor" means any person or entity that distributes its Work.
+
+"Software" means the original work of authorship made available under
+this License.
+
+"Work" means the Software and any additions to or derivative works of
+the Software that are made available under this License.
+
+The terms "reproduce," "reproduction," "derivative works," and
+"distribution" have the meaning as provided under U.S. copyright law;
+provided, however, that for the purposes of this License, derivative
+works shall not include works that remain separable from, or merely
+link (or bind by name) to the interfaces of, the Work.
+
+Works, including the Software, are "made available" under this License
+by including in or with the Work either (a) a copyright notice
+referencing the applicability of this License to the Work, or (b) a
+copy of this License.
+
+2. License Grants
+
+    2.1 Copyright Grant. Subject to the terms and conditions of this
+    License, each Licensor grants to you a perpetual, worldwide,
+    non-exclusive, royalty-free, copyright license to reproduce,
+    prepare derivative works of, publicly display, publicly perform,
+    sublicense and distribute its Work and any resulting derivative
+    works in any form.
+
+3. Limitations
+
+    3.1 Redistribution. You may reproduce or distribute the Work only
+    if (a) you do so under this License, (b) you include a complete
+    copy of this License with your distribution, and (c) you retain
+    without modification any copyright, patent, trademark, or
+    attribution notices that are present in the Work.
+
+    3.2 Derivative Works. You may specify that additional or different
+    terms apply to the use, reproduction, and distribution of your
+    derivative works of the Work ("Your Terms") only if (a) Your Terms
+    provide that the use limitation in Section 3.3 applies to your
+    derivative works, and (b) you identify the specific derivative
+    works that are subject to Your Terms. Notwithstanding Your Terms,
+    this License (including the redistribution requirements in Section
+    3.1) will continue to apply to the Work itself.
+
+    3.3 Use Limitation. The Work and any derivative works thereof only
+    may be used or intended for use non-commercially. Notwithstanding
+    the foregoing, NVIDIA and its affiliates may use the Work and any
+    derivative works commercially. As used herein, "non-commercially"
+    means for research or evaluation purposes only.
+
+    3.4 Patent Claims. If you bring or threaten to bring a patent claim
+    against any Licensor (including any claim, cross-claim or
+    counterclaim in a lawsuit) to enforce any patents that you allege
+    are infringed by any Work, then your rights under this License from
+    such Licensor (including the grant in Section 2.1) will terminate
+    immediately.
+
+    3.5 Trademarks. This License does not grant any rights to use any
+    Licensor’s or its affiliates’ names, logos, or trademarks, except
+    as necessary to reproduce the notices described in this License.
+
+    3.6 Termination. If you violate any term of this License, then your
+    rights under this License (including the grant in Section 2.1) will
+    terminate immediately.
+
+4. Disclaimer of Warranty.
+
+THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
+NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
+THIS LICENSE.
+
+5. Limitation of Liability.
+
+EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
+THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
+SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
+INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
+OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
+(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
+LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
+COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
+THE POSSIBILITY OF SUCH DAMAGES.
+
+=======================================================================

dnnlib/__init__.py
ADDED
@@ -0,0 +1,9 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+from .util import EasyDict, make_cache_dir_path

dnnlib/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (206 Bytes).
dnnlib/__pycache__/util.cpython-38.pyc
ADDED
Binary file (13.7 kB).
dnnlib/util.py
ADDED
@@ -0,0 +1,473 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+"""Miscellaneous utility classes and functions."""
+
+import ctypes
+import fnmatch
+import importlib
+import inspect
+import numpy as np
+import os
+import shutil
+import sys
+import types
+import io
+import pickle
+import re
+import requests
+import html
+import hashlib
+import glob
+import tempfile
+import urllib
+import urllib.request
+import uuid
+
+from distutils.util import strtobool
+from typing import Any, List, Tuple, Union
+
+
+class EasyDict(dict):
+    """Convenience class that behaves like a dict but allows access with the attribute syntax."""
+
+    def __getattr__(self, name: str) -> Any:
+        try:
+            return self[name]
+        except KeyError:
+            raise AttributeError(name)
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        self[name] = value
+
+    def __delattr__(self, name: str) -> None:
+        del self[name]
+
+
+class Logger(object):
+    """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
+
+    def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
+        self.file = None
+
+        if file_name is not None:
+            self.file = open(file_name, file_mode)
+
+        self.should_flush = should_flush
+        self.stdout = sys.stdout
+        self.stderr = sys.stderr
+
+        sys.stdout = self
+        sys.stderr = self
+
+    def __enter__(self) -> "Logger":
+        return self
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        self.close()
+
+    def write(self, text: Union[str, bytes]) -> None:
+        """Write text to stdout (and a file) and optionally flush."""
+        if isinstance(text, bytes):
+            text = text.decode()
+        if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
+            return
+
+        if self.file is not None:
+            self.file.write(text)
+
+        self.stdout.write(text)
+
+        if self.should_flush:
+            self.flush()
+
+    def flush(self) -> None:
+        """Flush written text to both stdout and a file, if open."""
+        if self.file is not None:
+            self.file.flush()
+
+        self.stdout.flush()
+
+    def close(self) -> None:
+        """Flush, close possible files, and remove stdout/stderr mirroring."""
+        self.flush()
+
+        # if using multiple loggers, prevent closing in wrong order
+        if sys.stdout is self:
+            sys.stdout = self.stdout
+        if sys.stderr is self:
+            sys.stderr = self.stderr
+
+        if self.file is not None:
+            self.file.close()
+            self.file = None
+
+
+# Cache directories
+# ------------------------------------------------------------------------------------------
+
+_dnnlib_cache_dir = None
+
+def set_cache_dir(path: str) -> None:
+    global _dnnlib_cache_dir
+    _dnnlib_cache_dir = path
+
+def make_cache_dir_path(*paths: str) -> str:
+    if _dnnlib_cache_dir is not None:
+        return os.path.join(_dnnlib_cache_dir, *paths)
+    if 'DNNLIB_CACHE_DIR' in os.environ:
+        return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
+    if 'HOME' in os.environ:
+        return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
+    if 'USERPROFILE' in os.environ:
+        return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
+    return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
+
+# Small util functions
+# ------------------------------------------------------------------------------------------
+
+
+def format_time(seconds: Union[int, float]) -> str:
+    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+    s = int(np.rint(seconds))
+
+    if s < 60:
+        return "{0}s".format(s)
+    elif s < 60 * 60:
+        return "{0}m {1:02}s".format(s // 60, s % 60)
+    elif s < 24 * 60 * 60:
+        return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
+    else:
+        return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
+
+
+def ask_yes_no(question: str) -> bool:
+    """Ask the user the question until the user inputs a valid answer."""
+    while True:
+        try:
+            print("{0} [y/n]".format(question))
+            return strtobool(input().lower())
+        except ValueError:
+            pass
+
+
+def tuple_product(t: Tuple) -> Any:
+    """Calculate the product of the tuple elements."""
+    result = 1
+
+    for v in t:
+        result *= v
+
+    return result
+
+
+_str_to_ctype = {
+    "uint8": ctypes.c_ubyte,
+    "uint16": ctypes.c_uint16,
+    "uint32": ctypes.c_uint32,
+    "uint64": ctypes.c_uint64,
+    "int8": ctypes.c_byte,
+    "int16": ctypes.c_int16,
+    "int32": ctypes.c_int32,
+    "int64": ctypes.c_int64,
+    "float32": ctypes.c_float,
+    "float64": ctypes.c_double
+}
+
+
+def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
+    """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
+    type_str = None
+
+    if isinstance(type_obj, str):
+        type_str = type_obj
+    elif hasattr(type_obj, "__name__"):
+        type_str = type_obj.__name__
+    elif hasattr(type_obj, "name"):
+        type_str = type_obj.name
+    else:
+        raise RuntimeError("Cannot infer type name from input")
+
+    assert type_str in _str_to_ctype.keys()
+
+    my_dtype = np.dtype(type_str)
+    my_ctype = _str_to_ctype[type_str]
+
+    assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
+
+    return my_dtype, my_ctype
+
+
+def is_pickleable(obj: Any) -> bool:
+    try:
+        with io.BytesIO() as stream:
+            pickle.dump(obj, stream)
+            return True
+    except:
+        return False
+
+
+# Functionality to import modules/objects by name, and call functions by name
+# ------------------------------------------------------------------------------------------
+
+def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
+    """Searches for the underlying module behind the name to some python object.
+    Returns the module and the object name (original name with module part removed)."""
+
+    # allow convenience shorthands, substitute them by full names
+    obj_name = re.sub("^np.", "numpy.", obj_name)
+    obj_name = re.sub("^tf.", "tensorflow.", obj_name)
+
+    # list alternatives for (module_name, local_obj_name)
+    parts = obj_name.split(".")
+    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
+
+    # try each alternative in turn
+    for module_name, local_obj_name in name_pairs:
+        try:
+            module = importlib.import_module(module_name) # may raise ImportError
+            get_obj_from_module(module, local_obj_name) # may raise AttributeError
+            return module, local_obj_name
+        except:
+            pass
+
+    # maybe some of the modules themselves contain errors?
+    for module_name, _local_obj_name in name_pairs:
+        try:
+            importlib.import_module(module_name) # may raise ImportError
+        except ImportError:
+            if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
+                raise
+
+    # maybe the requested attribute is missing?
+    for module_name, local_obj_name in name_pairs:
+        try:
+            module = importlib.import_module(module_name) # may raise ImportError
+            get_obj_from_module(module, local_obj_name) # may raise AttributeError
+        except ImportError:
+            pass
+
+    # we are out of luck, but we have no idea why
+    raise ImportError(obj_name)
+
+
+def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
+    """Traverses the object name and returns the last (rightmost) python object."""
+    if obj_name == '':
+        return module
+    obj = module
+    for part in obj_name.split("."):
+        obj = getattr(obj, part)
+    return obj
+
+
+def get_obj_by_name(name: str) -> Any:
+    """Finds the python object with the given name."""
+    module, obj_name = get_module_from_obj_name(name)
+    return get_obj_from_module(module, obj_name)
+
+
+def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
+    """Finds the python object with the given name and calls it as a function."""
+    assert func_name is not None
+    func_obj = get_obj_by_name(func_name)
+    assert callable(func_obj)
+    return func_obj(*args, **kwargs)
+
+
+def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
+    """Finds the python class with the given name and constructs it with the given arguments."""
+    return call_func_by_name(*args, func_name=class_name, **kwargs)
+
+
+def get_module_dir_by_obj_name(obj_name: str) -> str:
+    """Get the directory path of the module containing the given object name."""
+    module, _ = get_module_from_obj_name(obj_name)
+    return os.path.dirname(inspect.getfile(module))
+
+
+def is_top_level_function(obj: Any) -> bool:
+    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
+    return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
+
+
+def get_top_level_function_name(obj: Any) -> str:
+    """Return the fully-qualified name of a top-level function."""
+    assert is_top_level_function(obj)
+    module = obj.__module__
+    if module == '__main__':
+        module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
+    return module + "." + obj.__name__
+
+
+# File system helpers
+# ------------------------------------------------------------------------------------------
+
+def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
+    """List all files recursively in a given directory while ignoring given file and directory names.
+    Returns list of tuples containing both absolute and relative paths."""
+    assert os.path.isdir(dir_path)
+    base_name = os.path.basename(os.path.normpath(dir_path))
+
+    if ignores is None:
+        ignores = []
+
+    result = []
+
+    for root, dirs, files in os.walk(dir_path, topdown=True):
+        for ignore_ in ignores:
+            dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
+
+            # dirs need to be edited in-place
+            for d in dirs_to_remove:
+                dirs.remove(d)
+
+            files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
+
+        absolute_paths = [os.path.join(root, f) for f in files]
+        relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
+
+        if add_base_to_relative:
+            relative_paths = [os.path.join(base_name, p) for p in relative_paths]
+
+        assert len(absolute_paths) == len(relative_paths)
+        result += zip(absolute_paths, relative_paths)
+
+    return result
+
+
+def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
+    """Takes in a list of tuples of (src, dst) paths and copies files.
+    Will create all necessary directories."""
+    for file in files:
+        target_dir_name = os.path.dirname(file[1])
+
+        # will create all intermediate-level directories
+        if not os.path.exists(target_dir_name):
+            os.makedirs(target_dir_name)
+
+        shutil.copyfile(file[0], file[1])
+
+
+# URL helpers
+# ------------------------------------------------------------------------------------------
+
+def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
+    """Determine whether the given object is a valid URL string."""
+    if not isinstance(obj, str) or not "://" in obj:
+        return False
+    if allow_file_urls and obj.startswith('file://'):
+        return True
+    try:
+        res = requests.compat.urlparse(obj)
+        if not res.scheme or not res.netloc or not "." in res.netloc:
+            return False
+        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
+        if not res.scheme or not res.netloc or not "." in res.netloc:
+            return False
+    except:
+        return False
+    return True
+
+
+def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
+    """Download the given URL and return a binary-mode file object to access the data."""
+    assert num_attempts >= 1
+    assert not (return_filename and (not cache))
+
+    # Doesn't look like an URL scheme so interpret it as a local filename.
+    if not re.match('^[a-z]+://', url):
+        return url if return_filename else open(url, "rb")
+
+    # Handle file URLs. This code handles unusual file:// patterns that
+    # arise on Windows:
+    #
+    # file:///c:/foo.txt
+    #
+    # which would translate to a local '/c:/foo.txt' filename that's
+    # invalid. Drop the forward slash for such pathnames.
+    #
+    # If you touch this code path, you should test it on both Linux and
+    # Windows.
+    #
+    # Some internet resources suggest using urllib.request.url2pathname() but
+    # but that converts forward slashes to backslashes and this causes
+    # its own set of problems.
+    if url.startswith('file://'):
+        filename = urllib.parse.urlparse(url).path
+        if re.match(r'^/[a-zA-Z]:', filename):
+            filename = filename[1:]
+        return filename if return_filename else open(filename, "rb")
+
+    assert is_url(url)
+
+    # Lookup from cache.
+    if cache_dir is None:
+        cache_dir = make_cache_dir_path('downloads')
+
+    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
+    if cache:
+        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
+        if len(cache_files) == 1:
+            filename = cache_files[0]
+            return filename if return_filename else open(filename, "rb")
+
+    # Download.
+    url_name = None
+    url_data = None
+    with requests.Session() as session:
+        if verbose:
+            print("Downloading %s ..." % url, end="", flush=True)
+        for attempts_left in reversed(range(num_attempts)):
+            try:
+                with session.get(url) as res:
+                    res.raise_for_status()
+                    if len(res.content) == 0:
+                        raise IOError("No data received")
+
+                    if len(res.content) < 8192:
+                        content_str = res.content.decode("utf-8")
+                        if "download_warning" in res.headers.get("Set-Cookie", ""):
+                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
+                            if len(links) == 1:
+                                url = requests.compat.urljoin(url, links[0])
+                                raise IOError("Google Drive virus checker nag")
+                        if "Google Drive - Quota exceeded" in content_str:
+                            raise IOError("Google Drive download quota exceeded -- please try again later")
+
+                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
+                    url_name = match[1] if match else url
+                    url_data = res.content
+                    if verbose:
+                        print(" done")
+                    break
+            except KeyboardInterrupt:
+                raise
+            except:
+                if not attempts_left:
+                    if verbose:
+                        print(" failed")
+                    raise
+                if verbose:
+                    print(".", end="", flush=True)
+
+    # Save to cache.
+    if cache:
+        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
+        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
+        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
+        os.makedirs(cache_dir, exist_ok=True)
+        with open(temp_file, "wb") as f:
+            f.write(url_data)
+        os.replace(temp_file, cache_file) # atomic
+        if return_filename:
+            return cache_file
+
+    # Return data as file object.
+    assert not return_filename
+    return io.BytesIO(url_data)

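The two modules above are the stock StyleGAN2-ADA dnnlib helpers. As a rough orientation (not part of this diff), a minimal usage sketch might look like the following; the checkpoint path is a placeholder, and open_url simply falls back to a plain open() for local paths:

import dnnlib

# EasyDict gives attribute-style access on top of a plain dict.
opts = dnnlib.EasyDict(batch_size=4, resolution=1024)
print(opts.batch_size)

# open_url handles local paths, file:// and http(s) URLs, caching remote
# downloads under ~/.cache/dnnlib (or DNNLIB_CACHE_DIR) by default.
with dnnlib.util.open_url('pretrained_models/some_checkpoint.pkl', cache=False) as f:  # placeholder path
    data = f.read()
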
encoder4editing/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 omertov
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

encoder4editing/configs/__init__.py
ADDED
File without changes
encoder4editing/configs/data_configs.py
ADDED
@@ -0,0 +1,41 @@
+from configs import transforms_config
+from configs.paths_config import dataset_paths
+
+
+DATASETS = {
+    'ffhq_encode': {
+        'transforms': transforms_config.EncodeTransforms,
+        'train_source_root': dataset_paths['ffhq'],
+        'train_target_root': dataset_paths['ffhq'],
+        'test_source_root': dataset_paths['celeba_test'],
+        'test_target_root': dataset_paths['celeba_test'],
+    },
+    'cars_encode': {
+        'transforms': transforms_config.CarsEncodeTransforms,
+        'train_source_root': dataset_paths['cars_train'],
+        'train_target_root': dataset_paths['cars_train'],
+        'test_source_root': dataset_paths['cars_test'],
+        'test_target_root': dataset_paths['cars_test'],
+    },
+    'horse_encode': {
+        'transforms': transforms_config.EncodeTransforms,
+        'train_source_root': dataset_paths['horse_train'],
+        'train_target_root': dataset_paths['horse_train'],
+        'test_source_root': dataset_paths['horse_test'],
+        'test_target_root': dataset_paths['horse_test'],
+    },
+    'church_encode': {
+        'transforms': transforms_config.EncodeTransforms,
+        'train_source_root': dataset_paths['church_train'],
+        'train_target_root': dataset_paths['church_train'],
+        'test_source_root': dataset_paths['church_test'],
+        'test_target_root': dataset_paths['church_test'],
+    },
+    'cats_encode': {
+        'transforms': transforms_config.EncodeTransforms,
+        'train_source_root': dataset_paths['cats_train'],
+        'train_target_root': dataset_paths['cats_train'],
+        'test_source_root': dataset_paths['cats_test'],
+        'test_target_root': dataset_paths['cats_test'],
+    }
+}

encoder4editing/configs/paths_config.py
ADDED
@@ -0,0 +1,28 @@
+dataset_paths = {
+    # Face Datasets (In the paper: FFHQ - train, CelebAHQ - test)
+    'ffhq': '',
+    'celeba_test': '',
+
+    # Cars Dataset (In the paper: Stanford cars)
+    'cars_train': '',
+    'cars_test': '',
+
+    # Horse Dataset (In the paper: LSUN Horse)
+    'horse_train': '',
+    'horse_test': '',
+
+    # Church Dataset (In the paper: LSUN Church)
+    'church_train': '',
+    'church_test': '',
+
+    # Cats Dataset (In the paper: LSUN Cat)
+    'cats_train': '',
+    'cats_test': ''
+}
+
+model_paths = {
+    'stylegan_ffhq': 'pretrained_models/stylegan2-ffhq-config-f.pt',
+    'ir_se50': 'pretrained_models/model_ir_se50.pth',
+    'shape_predictor': 'pretrained_models/shape_predictor_68_face_landmarks.dat',
+    'moco': 'pretrained_models/moco_v2_800ep_pretrain.pth'
+}

encoder4editing/configs/transforms_config.py
ADDED
@@ -0,0 +1,62 @@
+from abc import abstractmethod
+import torchvision.transforms as transforms
+
+
+class TransformsConfig(object):
+
+    def __init__(self, opts):
+        self.opts = opts
+
+    @abstractmethod
+    def get_transforms(self):
+        pass
+
+
+class EncodeTransforms(TransformsConfig):
+
+    def __init__(self, opts):
+        super(EncodeTransforms, self).__init__(opts)
+
+    def get_transforms(self):
+        transforms_dict = {
+            'transform_gt_train': transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.RandomHorizontalFlip(0.5),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+            'transform_source': None,
+            'transform_test': transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+            'transform_inference': transforms.Compose([
+                transforms.Resize((256, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
+        }
+        return transforms_dict
+
+
+class CarsEncodeTransforms(TransformsConfig):
+
+    def __init__(self, opts):
+        super(CarsEncodeTransforms, self).__init__(opts)
+
+    def get_transforms(self):
+        transforms_dict = {
+            'transform_gt_train': transforms.Compose([
+                transforms.Resize((192, 256)),
+                transforms.RandomHorizontalFlip(0.5),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+            'transform_source': None,
+            'transform_test': transforms.Compose([
+                transforms.Resize((192, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+            'transform_inference': transforms.Compose([
+                transforms.Resize((192, 256)),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
+        }
+        return transforms_dict

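As an illustrative sketch (not part of the diff), these transform configs are resolved through the DATASETS registry in data_configs.py and instantiated with the run options; the Namespace below is a stand-in for the real options object used by the repository's scripts:

from argparse import Namespace
from configs.data_configs import DATASETS

opts = Namespace(dataset_type='ffhq_encode')  # stand-in options object
dataset_args = DATASETS[opts.dataset_type]
transforms_dict = dataset_args['transforms'](opts).get_transforms()
# transforms_dict['transform_inference'] resizes to 256x256, converts to a
# tensor and normalizes to [-1, 1] before the image is fed to the encoder.
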
encoder4editing/criteria/__init__.py
ADDED
File without changes
encoder4editing/criteria/id_loss.py
ADDED
@@ -0,0 +1,47 @@
+import torch
+from torch import nn
+from configs.paths_config import model_paths
+from models.encoders.model_irse import Backbone
+
+
+class IDLoss(nn.Module):
+    def __init__(self):
+        super(IDLoss, self).__init__()
+        print('Loading ResNet ArcFace')
+        self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
+        self.facenet.load_state_dict(torch.load(model_paths['ir_se50']))
+        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
+        self.facenet.eval()
+        for module in [self.facenet, self.face_pool]:
+            for param in module.parameters():
+                param.requires_grad = False
+
+    def extract_feats(self, x):
+        x = x[:, :, 35:223, 32:220]  # Crop interesting region
+        x = self.face_pool(x)
+        x_feats = self.facenet(x)
+        return x_feats
+
+    def forward(self, y_hat, y, x):
+        n_samples = x.shape[0]
+        x_feats = self.extract_feats(x)
+        y_feats = self.extract_feats(y)  # Otherwise use the feature from there
+        y_hat_feats = self.extract_feats(y_hat)
+        y_feats = y_feats.detach()
+        loss = 0
+        sim_improvement = 0
+        id_logs = []
+        count = 0
+        for i in range(n_samples):
+            diff_target = y_hat_feats[i].dot(y_feats[i])
+            diff_input = y_hat_feats[i].dot(x_feats[i])
+            diff_views = y_feats[i].dot(x_feats[i])
+            id_logs.append({'diff_target': float(diff_target),
+                            'diff_input': float(diff_input),
+                            'diff_views': float(diff_views)})
+            loss += 1 - diff_target
+            id_diff = float(diff_target) - float(diff_views)
+            sim_improvement += id_diff
+            count += 1
+
+        return loss / count, sim_improvement / count, id_logs

encoder4editing/criteria/lpips/__init__.py
ADDED
File without changes
encoder4editing/criteria/lpips/lpips.py
ADDED
@@ -0,0 +1,35 @@
+import torch
+import torch.nn as nn
+
+from criteria.lpips.networks import get_network, LinLayers
+from criteria.lpips.utils import get_state_dict
+
+
+class LPIPS(nn.Module):
+    r"""Creates a criterion that measures
+    Learned Perceptual Image Patch Similarity (LPIPS).
+    Arguments:
+        net_type (str): the network type to compare the features:
+                        'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
+        version (str): the version of LPIPS. Default: 0.1.
+    """
+    def __init__(self, net_type: str = 'alex', version: str = '0.1'):
+
+        assert version in ['0.1'], 'v0.1 is only supported now'
+
+        super(LPIPS, self).__init__()
+
+        # pretrained network
+        self.net = get_network(net_type).to("cuda")
+
+        # linear layers
+        self.lin = LinLayers(self.net.n_channels_list).to("cuda")
+        self.lin.load_state_dict(get_state_dict(net_type, version))
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor):
+        feat_x, feat_y = self.net(x), self.net(y)
+
+        diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
+        res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
+
+        return torch.sum(torch.cat(res, 0)) / x.shape[0]

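A hedged usage sketch for this criterion (not part of the diff): the module hard-codes .to("cuda"), so a GPU is assumed, and the tensors below are random stand-ins for normalized NCHW image batches:

import torch
from criteria.lpips.lpips import LPIPS

lpips_loss = LPIPS(net_type='alex').eval()
x = torch.rand(2, 3, 256, 256, device='cuda')  # stand-in for generated images
y = torch.rand(2, 3, 256, 256, device='cuda')  # stand-in for target images
loss = lpips_loss(x, y)  # scalar LPIPS distance averaged over the batch
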
encoder4editing/criteria/lpips/networks.py
ADDED
@@ -0,0 +1,96 @@
+from typing import Sequence
+
+from itertools import chain
+
+import torch
+import torch.nn as nn
+from torchvision import models
+
+from criteria.lpips.utils import normalize_activation
+
+
+def get_network(net_type: str):
+    if net_type == 'alex':
+        return AlexNet()
+    elif net_type == 'squeeze':
+        return SqueezeNet()
+    elif net_type == 'vgg':
+        return VGG16()
+    else:
+        raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
+
+
+class LinLayers(nn.ModuleList):
+    def __init__(self, n_channels_list: Sequence[int]):
+        super(LinLayers, self).__init__([
+            nn.Sequential(
+                nn.Identity(),
+                nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
+            ) for nc in n_channels_list
+        ])
+
+        for param in self.parameters():
+            param.requires_grad = False
+
+
+class BaseNet(nn.Module):
+    def __init__(self):
+        super(BaseNet, self).__init__()
+
+        # register buffer
+        self.register_buffer(
+            'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
+        self.register_buffer(
+            'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
+
+    def set_requires_grad(self, state: bool):
+        for param in chain(self.parameters(), self.buffers()):
+            param.requires_grad = state
+
+    def z_score(self, x: torch.Tensor):
+        return (x - self.mean) / self.std
+
+    def forward(self, x: torch.Tensor):
+        x = self.z_score(x)
+
+        output = []
+        for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
+            x = layer(x)
+            if i in self.target_layers:
+                output.append(normalize_activation(x))
+            if len(output) == len(self.target_layers):
+                break
+        return output
+
+
+class SqueezeNet(BaseNet):
+    def __init__(self):
+        super(SqueezeNet, self).__init__()
+
+        self.layers = models.squeezenet1_1(True).features
+        self.target_layers = [2, 5, 8, 10, 11, 12, 13]
+        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
+
+        self.set_requires_grad(False)
+
+
+class AlexNet(BaseNet):
+    def __init__(self):
+        super(AlexNet, self).__init__()
+
+        self.layers = models.alexnet(True).features
+        self.target_layers = [2, 5, 8, 10, 12]
+        self.n_channels_list = [64, 192, 384, 256, 256]
+
+        self.set_requires_grad(False)
+
+
+class VGG16(BaseNet):
+    def __init__(self):
+        super(VGG16, self).__init__()
+
+        self.layers = models.vgg16(True).features
+        self.target_layers = [4, 9, 16, 23, 30]
+        self.n_channels_list = [64, 128, 256, 512, 512]
+
+        self.set_requires_grad(False)

encoder4editing/criteria/lpips/utils.py
ADDED
@@ -0,0 +1,30 @@
+from collections import OrderedDict
+
+import torch
+
+
+def normalize_activation(x, eps=1e-10):
+    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
+    return x / (norm_factor + eps)
+
+
+def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
+    # build url
+    url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
+        + f'master/lpips/weights/v{version}/{net_type}.pth'
+
+    # download
+    old_state_dict = torch.hub.load_state_dict_from_url(
+        url, progress=True,
+        map_location=None if torch.cuda.is_available() else torch.device('cpu')
+    )
+
+    # rename keys
+    new_state_dict = OrderedDict()
+    for key, val in old_state_dict.items():
+        new_key = key
+        new_key = new_key.replace('lin', '')
+        new_key = new_key.replace('model.', '')
+        new_state_dict[new_key] = val
+
+    return new_state_dict

encoder4editing/criteria/moco_loss.py
ADDED
@@ -0,0 +1,71 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from configs.paths_config import model_paths
+
+
+class MocoLoss(nn.Module):
+
+    def __init__(self, opts):
+        super(MocoLoss, self).__init__()
+        print("Loading MOCO model from path: {}".format(model_paths["moco"]))
+        self.model = self.__load_model()
+        self.model.eval()
+        for param in self.model.parameters():
+            param.requires_grad = False
+
+    @staticmethod
+    def __load_model():
+        import torchvision.models as models
+        model = models.__dict__["resnet50"]()
+        # freeze all layers but the last fc
+        for name, param in model.named_parameters():
+            if name not in ['fc.weight', 'fc.bias']:
+                param.requires_grad = False
+        checkpoint = torch.load(model_paths['moco'], map_location="cpu")
+        state_dict = checkpoint['state_dict']
+        # rename moco pre-trained keys
+        for k in list(state_dict.keys()):
+            # retain only encoder_q up to before the embedding layer
+            if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
+                # remove prefix
+                state_dict[k[len("module.encoder_q."):]] = state_dict[k]
+            # delete renamed or unused k
+            del state_dict[k]
+        msg = model.load_state_dict(state_dict, strict=False)
+        assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
+        # remove output layer
+        model = nn.Sequential(*list(model.children())[:-1]).cuda()
+        return model
+
+    def extract_feats(self, x):
+        x = F.interpolate(x, size=224)
+        x_feats = self.model(x)
+        x_feats = nn.functional.normalize(x_feats, dim=1)
+        x_feats = x_feats.squeeze()
+        return x_feats
+
+    def forward(self, y_hat, y, x):
+        n_samples = x.shape[0]
+        x_feats = self.extract_feats(x)
+        y_feats = self.extract_feats(y)
+        y_hat_feats = self.extract_feats(y_hat)
+        y_feats = y_feats.detach()
+        loss = 0
+        sim_improvement = 0
+        sim_logs = []
+        count = 0
+        for i in range(n_samples):
+            diff_target = y_hat_feats[i].dot(y_feats[i])
+            diff_input = y_hat_feats[i].dot(x_feats[i])
+            diff_views = y_feats[i].dot(x_feats[i])
+            sim_logs.append({'diff_target': float(diff_target),
+                             'diff_input': float(diff_input),
+                             'diff_views': float(diff_views)})
+            loss += 1 - diff_target
+            sim_diff = float(diff_target) - float(diff_views)
+            sim_improvement += sim_diff
+            count += 1
+
+        return loss / count, sim_improvement / count, sim_logs

encoder4editing/criteria/w_norm.py
ADDED
@@ -0,0 +1,14 @@
+import torch
+from torch import nn
+
+
+class WNormLoss(nn.Module):
+
+    def __init__(self, start_from_latent_avg=True):
+        super(WNormLoss, self).__init__()
+        self.start_from_latent_avg = start_from_latent_avg
+
+    def forward(self, latent, latent_avg=None):
+        if self.start_from_latent_avg:
+            latent = latent - latent_avg
+        return torch.sum(latent.norm(2, dim=(1, 2))) / latent.shape[0]

encoder4editing/datasets/__init__.py
ADDED
File without changes
encoder4editing/datasets/gt_res_dataset.py
ADDED
@@ -0,0 +1,32 @@
+#!/usr/bin/python
+# encoding: utf-8
+import os
+from torch.utils.data import Dataset
+from PIL import Image
+import torch
+
+class GTResDataset(Dataset):
+
+    def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None):
+        self.pairs = []
+        for f in os.listdir(root_path):
+            image_path = os.path.join(root_path, f)
+            gt_path = os.path.join(gt_dir, f)
+            if f.endswith(".jpg") or f.endswith(".png"):
+                self.pairs.append([image_path, gt_path.replace('.png', '.jpg'), None])
+        self.transform = transform
+        self.transform_train = transform_train
+
+    def __len__(self):
+        return len(self.pairs)
+
+    def __getitem__(self, index):
+        from_path, to_path, _ = self.pairs[index]
+        from_im = Image.open(from_path).convert('RGB')
+        to_im = Image.open(to_path).convert('RGB')
+
+        if self.transform:
+            to_im = self.transform(to_im)
+            from_im = self.transform(from_im)
+
+        return from_im, to_im

encoder4editing/datasets/images_dataset.py
ADDED
@@ -0,0 +1,33 @@
+from torch.utils.data import Dataset
+from PIL import Image
+from utils import data_utils
+
+
+class ImagesDataset(Dataset):
+
+    def __init__(self, source_root, target_root, opts, target_transform=None, source_transform=None):
+        self.source_paths = sorted(data_utils.make_dataset(source_root))
+        self.target_paths = sorted(data_utils.make_dataset(target_root))
+        self.source_transform = source_transform
+        self.target_transform = target_transform
+        self.opts = opts
+
+    def __len__(self):
+        return len(self.source_paths)
+
+    def __getitem__(self, index):
+        from_path = self.source_paths[index]
+        from_im = Image.open(from_path)
+        from_im = from_im.convert('RGB')
+
+        to_path = self.target_paths[index]
+        to_im = Image.open(to_path).convert('RGB')
+        if self.target_transform:
+            to_im = self.target_transform(to_im)
+
+        if self.source_transform:
+            from_im = self.source_transform(from_im)
+        else:
+            from_im = to_im
+
+        return from_im, to_im

encoder4editing/datasets/inference_dataset.py
ADDED
@@ -0,0 +1,25 @@
+from torch.utils.data import Dataset
+from PIL import Image
+from utils import data_utils
+
+
+class InferenceDataset(Dataset):
+
+    def __init__(self, root, opts, transform=None, preprocess=None):
+        self.paths = sorted(data_utils.make_dataset(root))
+        self.transform = transform
+        self.preprocess = preprocess
+        self.opts = opts
+
+    def __len__(self):
+        return len(self.paths)
+
+    def __getitem__(self, index):
+        from_path = self.paths[index]
+        if self.preprocess is not None:
+            from_im = self.preprocess(from_path)
+        else:
+            from_im = Image.open(from_path).convert('RGB')
+        if self.transform:
+            from_im = self.transform(from_im)
+        return from_im

encoder4editing/editings/ganspace.py
ADDED
@@ -0,0 +1,22 @@
+import torch
+
+
+def edit(latents, pca, edit_directions):
+    edit_latents = []
+    for latent in latents:
+        for pca_idx, start, end, strength in edit_directions:
+            delta = get_delta(pca, latent, pca_idx, strength)
+            delta_padded = torch.zeros(latent.shape).to('cuda')
+            delta_padded[start:end] += delta.repeat(end - start, 1)
+            edit_latents.append(latent + delta_padded)
+    return torch.stack(edit_latents)
+
+
+def get_delta(pca, latent, idx, strength):
+    # pca: ganspace checkpoint. latent: (16, 512) w+
+    w_centered = latent - pca['mean'].to('cuda')
+    lat_comp = pca['comp'].to('cuda')
+    lat_std = pca['std'].to('cuda')
+    w_coord = torch.sum(w_centered[0].reshape(-1)*lat_comp[idx].reshape(-1)) / lat_std[idx]
+    delta = (strength - w_coord)*lat_comp[idx]*lat_std[idx]
+    return delta

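For orientation (illustrative only, not part of the diff): edit_directions is a list of (pca_idx, start_layer, end_layer, strength) tuples, pca is a loaded GANSpace checkpoint such as the ffhq_pca.pt file added below, and latents is a batch of W+ codes on the GPU; the tuple values here are arbitrary examples, not named directions:

import torch
from editings import ganspace

pca = torch.load('editings/ganspace_pca/ffhq_pca.pt')  # dict with 'mean', 'comp', 'std'
latents = torch.randn(1, 18, 512, device='cuda')       # stand-in W+ code from the encoder
directions = [(54, 7, 8, 20.0)]                        # arbitrary (pca_idx, start, end, strength)
edited = ganspace.edit(latents, pca, directions)       # one edited latent per direction per input
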
encoder4editing/editings/ganspace_pca/cars_pca.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a5c3bae61ecd85de077fbbf103f5f30cf4b7676fe23a8508166eaf2ce73c8392
+size 167562

encoder4editing/editings/ganspace_pca/ffhq_pca.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4d7f9df1c96180d9026b9cb8d04753579fbf385f321a9d0e263641601c5e5d36
+size 167562

encoder4editing/editings/interfacegan_directions/age.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:50074516b1629707d89b5e43d6b8abd1792212fa3b961a87a11323d6a5222ae0
+size 2808

encoder4editing/editings/interfacegan_directions/pose.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:736e0eacc8488fa0b020a2e7bd256b957284c364191dfea693705e5d06d43e7d
+size 37624

encoder4editing/editings/interfacegan_directions/smile.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:817a7e732b59dee9eba862bec8bd7e8373568443bc9f9731a21cf9b0356f0653
+size 2808

encoder4editing/editings/latent_editor.py
ADDED
@@ -0,0 +1,45 @@
+import torch
+import sys
+sys.path.append(".")
+sys.path.append("..")
+from editings import ganspace, sefa
+from utils.common import tensor2im
+
+
+class LatentEditor(object):
+    def __init__(self, stylegan_generator, is_cars=False):
+        self.generator = stylegan_generator
+        self.is_cars = is_cars  # Since the cars StyleGAN output is 384x512, there is a need to crop the 512x512 output.
+
+    def apply_ganspace(self, latent, ganspace_pca, edit_directions):
+        edit_latents = ganspace.edit(latent, ganspace_pca, edit_directions)
+        return self._latents_to_image(edit_latents)
+
+    def apply_interfacegan(self, latent, direction, factor=1, factor_range=None):
+        edit_latents = []
+        if factor_range is not None:  # Apply a range of editing factors. for example, (-5, 5)
+            for f in range(*factor_range):
+                edit_latent = latent + f * direction
+                edit_latents.append(edit_latent)
+            edit_latents = torch.cat(edit_latents)
+        else:
+            edit_latents = latent + factor * direction
+        return self._latents_to_image(edit_latents)
+
+    def apply_sefa(self, latent, indices=[2, 3, 4, 5], **kwargs):
+        edit_latents = sefa.edit(self.generator, latent, indices, **kwargs)
+        return self._latents_to_image(edit_latents)
+
+    # Currently, in order to apply StyleFlow editings, one should run inference,
+    # save the latent codes and load them form the official StyleFlow repository.
+    # def apply_styleflow(self):
+    #     pass
+
+    def _latents_to_image(self, latents):
+        with torch.no_grad():
+            images, _ = self.generator([latents], randomize_noise=False, input_is_latent=True)
+            if self.is_cars:
+                images = images[:, :, 64:448, :]  # 512x512 -> 384x512
+        horizontal_concat_image = torch.cat(list(images), 2)
+        final_image = tensor2im(horizontal_concat_image)
+        return final_image

encoder4editing/editings/sefa.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
|
6 |
+
def edit(generator, latents, indices, semantics=1, start_distance=-15.0, end_distance=15.0, num_samples=1, step=11):
|
7 |
+
|
8 |
+
layers, boundaries, values = factorize_weight(generator, indices)
|
9 |
+
codes = latents.detach().cpu().numpy() # (1,18,512)
|
10 |
+
|
11 |
+
# Generate visualization pages.
|
12 |
+
distances = np.linspace(start_distance, end_distance, step)
|
13 |
+
num_sam = num_samples
|
14 |
+
num_sem = semantics
|
15 |
+
|
16 |
+
edited_latents = []
|
17 |
+
for sem_id in tqdm(range(num_sem), desc='Semantic ', leave=False):
|
18 |
+
boundary = boundaries[sem_id:sem_id + 1]
|
19 |
+
for sam_id in tqdm(range(num_sam), desc='Sample ', leave=False):
|
20 |
+
code = codes[sam_id:sam_id + 1]
|
21 |
+
for col_id, d in enumerate(distances, start=1):
|
22 |
+
temp_code = code.copy()
|
23 |
+
temp_code[:, layers, :] += boundary * d
|
24 |
+
edited_latents.append(torch.from_numpy(temp_code).float().cuda())
|
25 |
+
return torch.cat(edited_latents)
|
26 |
+
|
27 |
+
|
28 |
+
def factorize_weight(g_ema, layers='all'):
|
29 |
+
|
30 |
+
weights = []
|
31 |
+
if layers == 'all' or 0 in layers:
|
32 |
+
weight = g_ema.conv1.conv.modulation.weight.T
|
33 |
+
weights.append(weight.cpu().detach().numpy())
|
34 |
+
|
35 |
+
if layers == 'all':
|
36 |
+
layers = list(range(g_ema.num_layers - 1))
|
37 |
+
else:
|
38 |
+
layers = [l - 1 for l in layers if l != 0]
|
39 |
+
|
40 |
+
for idx in layers:
|
41 |
+
weight = g_ema.convs[idx].conv.modulation.weight.T
|
42 |
+
weights.append(weight.cpu().detach().numpy())
|
43 |
+
weight = np.concatenate(weights, axis=1).astype(np.float32)
|
44 |
+
weight = weight / np.linalg.norm(weight, axis=0, keepdims=True)
|
45 |
+
eigen_values, eigen_vectors = np.linalg.eig(weight.dot(weight.T))
|
46 |
+
return layers, eigen_vectors.T, eigen_values
|
encoder4editing/environment/e4e_env.yaml
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: e4e_env
|
2 |
+
channels:
|
3 |
+
- conda-forge
|
4 |
+
- defaults
|
5 |
+
dependencies:
|
6 |
+
- _libgcc_mutex=0.1=main
|
7 |
+
- ca-certificates=2020.4.5.1=hecc5488_0
|
8 |
+
- certifi=2020.4.5.1=py36h9f0ad1d_0
|
9 |
+
- libedit=3.1.20181209=hc058e9b_0
|
10 |
+
- libffi=3.2.1=hd88cf55_4
|
11 |
+
- libgcc-ng=9.1.0=hdf63c60_0
|
12 |
+
- libstdcxx-ng=9.1.0=hdf63c60_0
|
13 |
+
- ncurses=6.2=he6710b0_1
|
14 |
+
- ninja=1.10.0=hc9558a2_0
|
15 |
+
- openssl=1.1.1g=h516909a_0
|
16 |
+
- pip=20.0.2=py36_3
|
17 |
+
- python=3.6.7=h0371630_0
|
18 |
+
- python_abi=3.6=1_cp36m
|
19 |
+
- readline=7.0=h7b6447c_5
|
20 |
+
- setuptools=46.4.0=py36_0
|
21 |
+
- sqlite=3.31.1=h62c20be_1
|
22 |
+
- tk=8.6.8=hbc83047_0
|
23 |
+
- wheel=0.34.2=py36_0
|
24 |
+
- xz=5.2.5=h7b6447c_0
|
25 |
+
- zlib=1.2.11=h7b6447c_3
|
26 |
+
- pip:
|
27 |
+
- absl-py==0.9.0
|
28 |
+
- cachetools==4.1.0
|
29 |
+
- chardet==3.0.4
|
30 |
+
- cycler==0.10.0
|
31 |
+
- decorator==4.4.2
|
32 |
+
- future==0.18.2
|
33 |
+
- google-auth==1.15.0
|
34 |
+
- google-auth-oauthlib==0.4.1
|
35 |
+
- grpcio==1.29.0
|
36 |
+
- idna==2.9
|
37 |
+
- imageio==2.8.0
|
38 |
+
- importlib-metadata==1.6.0
|
39 |
+
- kiwisolver==1.2.0
|
40 |
+
- markdown==3.2.2
|
41 |
+
- matplotlib==3.2.1
|
42 |
+
- mxnet==1.6.0
|
43 |
+
- networkx==2.4
|
44 |
+
- numpy==1.18.4
|
45 |
+
- oauthlib==3.1.0
|
46 |
+
- opencv-python==4.2.0.34
|
47 |
+
- pillow==7.1.2
|
48 |
+
- protobuf==3.12.1
|
49 |
+
- pyasn1==0.4.8
|
50 |
+
- pyasn1-modules==0.2.8
|
51 |
+
- pyparsing==2.4.7
|
52 |
+
- python-dateutil==2.8.1
|
53 |
+
- pytorch-lightning==0.7.1
|
54 |
+
- pywavelets==1.1.1
|
55 |
+
- requests==2.23.0
|
56 |
+
- requests-oauthlib==1.3.0
|
57 |
+
- rsa==4.0
|
58 |
+
- scikit-image==0.17.2
|
59 |
+
- scipy==1.4.1
|
60 |
+
- six==1.15.0
|
61 |
+
- tensorboard==2.2.1
|
62 |
+
- tensorboard-plugin-wit==1.6.0.post3
|
63 |
+
- tensorboardx==1.9
|
64 |
+
- tifffile==2020.5.25
|
65 |
+
- torch==1.6.0
|
66 |
+
- torchvision==0.7.1
|
67 |
+
- tqdm==4.46.0
|
68 |
+
- urllib3==1.25.9
|
69 |
+
- werkzeug==1.0.1
|
70 |
+
- zipp==3.1.0
|
71 |
+
- pyaml
|
72 |
+
prefix: ~/anaconda3/envs/e4e_env
|
73 |
+
|
encoder4editing/infer.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
from argparse import Namespace
|
4 |
+
import time
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image
|
9 |
+
import torch
|
10 |
+
import torchvision.transforms as transforms
|
11 |
+
|
12 |
+
sys.path.append(".")
|
13 |
+
sys.path.append("..")
|
14 |
+
|
15 |
+
from utils.common import tensor2im
|
16 |
+
from models.psp import pSp # we use the pSp framework to load the e4e encoder.
|
17 |
+
experiment_type = 'ffhq_encode'
|
18 |
+
|
19 |
+
parser = argparse.ArgumentParser()
|
20 |
+
parser.add_argument('--input_image', type=str, default="", help='input image path')
|
21 |
+
args = parser.parse_args()
|
22 |
+
opts = vars(args)
|
23 |
+
print(opts)
|
24 |
+
image_path = opts["input_image"]
|
25 |
+
|
26 |
+
def get_download_model_command(file_id, file_name):
|
27 |
+
""" Get wget download command for downloading the desired model and save to directory pretrained_models. """
|
28 |
+
current_directory = os.getcwd()
|
29 |
+
save_path = "encoder4editing/saves"
|
30 |
+
if not os.path.exists(save_path):
|
31 |
+
os.makedirs(save_path)
|
32 |
+
url = r"""wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id={FILE_ID}" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt""".format(FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=save_path)
|
33 |
+
return url
|
34 |
+
|
35 |
+
MODEL_PATHS = {
|
36 |
+
"ffhq_encode": {"id": "1cUv_reLE6k3604or78EranS7XzuVMWeO", "name": "e4e_ffhq_encode.pt"},
|
37 |
+
"cars_encode": {"id": "17faPqBce2m1AQeLCLHUVXaDfxMRU2QcV", "name": "e4e_cars_encode.pt"},
|
38 |
+
"horse_encode": {"id": "1TkLLnuX86B_BMo2ocYD0kX9kWh53rUVX", "name": "e4e_horse_encode.pt"},
|
39 |
+
"church_encode": {"id": "1-L0ZdnQLwtdy6-A_Ccgq5uNJGTqE7qBa", "name": "e4e_church_encode.pt"}
|
40 |
+
}
|
41 |
+
|
42 |
+
path = MODEL_PATHS[experiment_type]
|
43 |
+
download_command = get_download_model_command(file_id=path["id"], file_name=path["name"])
|
44 |
+
|
45 |
+
EXPERIMENT_DATA_ARGS = {
|
46 |
+
"ffhq_encode": {
|
47 |
+
"model_path": "encoder4editing/e4e_ffhq_encode.pt",
|
48 |
+
"image_path": "notebooks/images/input_img.jpg"
|
49 |
+
},
|
50 |
+
"cars_encode": {
|
51 |
+
"model_path": "pretrained_models/e4e_cars_encode.pt",
|
52 |
+
"image_path": "notebooks/images/car_img.jpg"
|
53 |
+
},
|
54 |
+
"horse_encode": {
|
55 |
+
"model_path": "pretrained_models/e4e_horse_encode.pt",
|
56 |
+
"image_path": "notebooks/images/horse_img.jpg"
|
57 |
+
},
|
58 |
+
"church_encode": {
|
59 |
+
"model_path": "pretrained_models/e4e_church_encode.pt",
|
60 |
+
"image_path": "notebooks/images/church_img.jpg"
|
61 |
+
}
|
62 |
+
|
63 |
+
}
|
64 |
+
# Setup required image transformations
|
65 |
+
EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[experiment_type]
|
66 |
+
if experiment_type == 'cars_encode':
|
67 |
+
EXPERIMENT_ARGS['transform'] = transforms.Compose([
|
68 |
+
transforms.Resize((192, 256)),
|
69 |
+
transforms.ToTensor(),
|
70 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
|
71 |
+
resize_dims = (256, 192)
|
72 |
+
else:
|
73 |
+
EXPERIMENT_ARGS['transform'] = transforms.Compose([
|
74 |
+
transforms.Resize((256, 256)),
|
75 |
+
transforms.ToTensor(),
|
76 |
+
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
|
77 |
+
resize_dims = (256, 256)
|
78 |
+
|
79 |
+
|
80 |
+
model_path = EXPERIMENT_ARGS['model_path']
|
81 |
+
ckpt = torch.load(model_path, map_location='cpu')
|
82 |
+
opts = ckpt['opts']
|
83 |
+
|
84 |
+
# update the training options
|
85 |
+
opts['checkpoint_path'] = model_path
|
86 |
+
opts= Namespace(**opts)
|
87 |
+
net = pSp(opts)
|
88 |
+
net.eval()
|
89 |
+
net.cuda()
|
90 |
+
print('Model successfully loaded!')
|
91 |
+
|
92 |
+
|
93 |
+
original_image = Image.open(image_path)
|
94 |
+
original_image = original_image.convert("RGB")
|
95 |
+
|
96 |
+
def run_alignment(image_path):
|
97 |
+
import dlib
|
98 |
+
from utils.alignment import align_face
|
99 |
+
predictor = dlib.shape_predictor("encoder4editing/shape_predictor_68_face_landmarks.dat")
|
100 |
+
aligned_image = align_face(filepath=image_path, predictor=predictor)
|
101 |
+
print("Aligned image has shape: {}".format(aligned_image.size))
|
102 |
+
return aligned_image
|
103 |
+
|
104 |
+
if experiment_type == "ffhq_encode":
|
105 |
+
input_image = run_alignment(image_path)
|
106 |
+
else:
|
107 |
+
input_image = original_image
|
108 |
+
|
109 |
+
input_image.resize(resize_dims)
|
110 |
+
|
111 |
+
img_transforms = EXPERIMENT_ARGS['transform']
|
112 |
+
transformed_image = img_transforms(input_image)
|
113 |
+
|
114 |
+
def display_alongside_source_image(result_image, source_image):
|
115 |
+
res = np.concatenate([np.array(source_image.resize(resize_dims)),
|
116 |
+
np.array(result_image.resize(resize_dims))], axis=1)
|
117 |
+
return Image.fromarray(res)
|
118 |
+
|
119 |
+
def run_on_batch(inputs, net):
|
120 |
+
images, latents = net(inputs.to("cuda").float(), randomize_noise=False, return_latents=True)
|
121 |
+
if experiment_type == 'cars_encode':
|
122 |
+
images = images[:, :, 32:224, :]
|
123 |
+
return images, latents
|
124 |
+
|
125 |
+
with torch.no_grad():
|
126 |
+
tic = time.time()
|
127 |
+
images, latents = run_on_batch(transformed_image.unsqueeze(0), net)
|
128 |
+
result_image, latent = images[0], latents[0]
|
129 |
+
toc = time.time()
|
130 |
+
print('Inference took {:.4f} seconds.'.format(toc - tic))
|
131 |
+
|
132 |
+
# Display inversion:
|
133 |
+
display_alongside_source_image(tensor2im(result_image), input_image)
|
134 |
+
np.savez(f'encoder4editing/projected_w.npz', w=latents.cpu().numpy())
|
encoder4editing/metrics/LEC.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import argparse
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from torch.utils.data import DataLoader
|
6 |
+
|
7 |
+
sys.path.append(".")
|
8 |
+
sys.path.append("..")
|
9 |
+
|
10 |
+
from configs import data_configs
|
11 |
+
from datasets.images_dataset import ImagesDataset
|
12 |
+
from utils.model_utils import setup_model
|
13 |
+
|
14 |
+
|
15 |
+
class LEC:
|
16 |
+
def __init__(self, net, is_cars=False):
|
17 |
+
"""
|
18 |
+
Latent Editing Consistency metric as proposed in the main paper.
|
19 |
+
:param net: e4e model loaded over the pSp framework.
|
20 |
+
:param is_cars: An indication as to whether or not to crop the middle of the StyleGAN's output images.
|
21 |
+
"""
|
22 |
+
self.net = net
|
23 |
+
self.is_cars = is_cars
|
24 |
+
|
25 |
+
def _encode(self, images):
|
26 |
+
"""
|
27 |
+
Encodes the given images into StyleGAN's latent space.
|
28 |
+
:param images: Tensor of shape NxCxHxW representing the images to be encoded.
|
29 |
+
:return: Tensor of shape NxKx512 representing the latent space embeddings of the given image (in W(K, *) space).
|
30 |
+
"""
|
31 |
+
codes = self.net.encoder(images)
|
32 |
+
assert codes.ndim == 3, f"Invalid latent codes shape, should be NxKx512 but is {codes.shape}"
|
33 |
+
# normalize with respect to the center of an average face
|
34 |
+
if self.net.opts.start_from_latent_avg:
|
35 |
+
codes = codes + self.net.latent_avg.repeat(codes.shape[0], 1, 1)
|
36 |
+
return codes
|
37 |
+
|
38 |
+
def _generate(self, codes):
|
39 |
+
"""
|
40 |
+
Generate the StyleGAN2 images of the given codes
|
41 |
+
:param codes: Tensor of shape NxKx512 representing the StyleGAN's latent codes (in W(K, *) space).
|
42 |
+
:return: Tensor of shape NxCxHxW representing the generated images.
|
43 |
+
"""
|
44 |
+
images, _ = self.net.decoder([codes], input_is_latent=True, randomize_noise=False, return_latents=True)
|
45 |
+
images = self.net.face_pool(images)
|
46 |
+
if self.is_cars:
|
47 |
+
images = images[:, :, 32:224, :]
|
48 |
+
return images
|
49 |
+
|
50 |
+
@staticmethod
|
51 |
+
def _filter_outliers(arr):
|
52 |
+
arr = np.array(arr)
|
53 |
+
|
54 |
+
lo = np.percentile(arr, 1, interpolation="lower")
|
55 |
+
hi = np.percentile(arr, 99, interpolation="higher")
|
56 |
+
return np.extract(
|
57 |
+
np.logical_and(lo <= arr, arr <= hi), arr
|
58 |
+
)
|
59 |
+
|
60 |
+
def calculate_metric(self, data_loader, edit_function, inverse_edit_function):
|
61 |
+
"""
|
62 |
+
Calculate the LEC metric score.
|
63 |
+
:param data_loader: An iterable that returns a tuple of (images, _), similar to the training data loader.
|
64 |
+
:param edit_function: A function that receives latent codes and performs a semantically meaningful edit in the
|
65 |
+
latent space.
|
66 |
+
:param inverse_edit_function: A function that receives latent codes and performs the inverse edit of the
|
67 |
+
`edit_function` parameter.
|
68 |
+
:return: The LEC metric score.
|
69 |
+
"""
|
70 |
+
distances = []
|
71 |
+
with torch.no_grad():
|
72 |
+
for batch in data_loader:
|
73 |
+
x, _ = batch
|
74 |
+
inputs = x.to(device).float()
|
75 |
+
|
76 |
+
codes = self._encode(inputs)
|
77 |
+
edited_codes = edit_function(codes)
|
78 |
+
edited_image = self._generate(edited_codes)
|
79 |
+
edited_image_inversion_codes = self._encode(edited_image)
|
80 |
+
inverse_edit_codes = inverse_edit_function(edited_image_inversion_codes)
|
81 |
+
|
82 |
+
dist = (codes - inverse_edit_codes).norm(2, dim=(1, 2)).mean()
|
83 |
+
distances.append(dist.to("cpu").numpy())
|
84 |
+
|
85 |
+
distances = self._filter_outliers(distances)
|
86 |
+
return distances.mean()
|
87 |
+
|
88 |
+
|
89 |
+
if __name__ == "__main__":
|
90 |
+
device = "cuda"
|
91 |
+
|
92 |
+
parser = argparse.ArgumentParser(description="LEC metric calculator")
|
93 |
+
|
94 |
+
parser.add_argument("--batch", type=int, default=8, help="batch size for the models")
|
95 |
+
parser.add_argument("--images_dir", type=str, default=None,
|
96 |
+
help="Path to the images directory on which we calculate the LEC score")
|
97 |
+
parser.add_argument("ckpt", metavar="CHECKPOINT", help="path to the model checkpoints")
|
98 |
+
|
99 |
+
args = parser.parse_args()
|
100 |
+
print(args)
|
101 |
+
|
102 |
+
net, opts = setup_model(args.ckpt, device)
|
103 |
+
dataset_args = data_configs.DATASETS[opts.dataset_type]
|
104 |
+
transforms_dict = dataset_args['transforms'](opts).get_transforms()
|
105 |
+
|
106 |
+
images_directory = dataset_args['test_source_root'] if args.images_dir is None else args.images_dir
|
107 |
+
test_dataset = ImagesDataset(source_root=images_directory,
|
108 |
+
target_root=images_directory,
|
109 |
+
source_transform=transforms_dict['transform_source'],
|
110 |
+
target_transform=transforms_dict['transform_test'],
|
111 |
+
opts=opts)
|
112 |
+
|
113 |
+
data_loader = DataLoader(test_dataset,
|
114 |
+
batch_size=args.batch,
|
115 |
+
shuffle=False,
|
116 |
+
num_workers=2,
|
117 |
+
drop_last=True)
|
118 |
+
|
119 |
+
print(f'dataset length: {len(test_dataset)}')
|
120 |
+
|
121 |
+
# In the following example, we are using an InterfaceGAN based editing to calculate the LEC metric.
|
122 |
+
# Change the provided example according to your domain and needs.
|
123 |
+
direction = torch.load('../editings/interfacegan_directions/age.pt').to(device)
|
124 |
+
|
125 |
+
def edit_func_example(codes):
|
126 |
+
return codes + 3 * direction
|
127 |
+
|
128 |
+
|
129 |
+
def inverse_edit_func_example(codes):
|
130 |
+
return codes - 3 * direction
|
131 |
+
|
132 |
+
lec = LEC(net, is_cars='car' in opts.dataset_type)
|
133 |
+
result = lec.calculate_metric(data_loader, edit_func_example, inverse_edit_func_example)
|
134 |
+
print(f"LEC: {result}")
|
encoder4editing/models/__init__.py
ADDED
File without changes
|
encoder4editing/models/discriminator.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
|
3 |
+
|
4 |
+
class LatentCodesDiscriminator(nn.Module):
|
5 |
+
def __init__(self, style_dim, n_mlp):
|
6 |
+
super().__init__()
|
7 |
+
|
8 |
+
self.style_dim = style_dim
|
9 |
+
|
10 |
+
layers = []
|
11 |
+
for i in range(n_mlp-1):
|
12 |
+
layers.append(
|
13 |
+
nn.Linear(style_dim, style_dim)
|
14 |
+
)
|
15 |
+
layers.append(nn.LeakyReLU(0.2))
|
16 |
+
layers.append(nn.Linear(512, 1))
|
17 |
+
self.mlp = nn.Sequential(*layers)
|
18 |
+
|
19 |
+
def forward(self, w):
|
20 |
+
return self.mlp(w)
|
encoder4editing/models/encoders/__init__.py
ADDED
File without changes
|
encoder4editing/models/encoders/helpers.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import namedtuple
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
|
5 |
+
|
6 |
+
"""
|
7 |
+
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
|
8 |
+
"""
|
9 |
+
|
10 |
+
|
11 |
+
class Flatten(Module):
|
12 |
+
def forward(self, input):
|
13 |
+
return input.view(input.size(0), -1)
|
14 |
+
|
15 |
+
|
16 |
+
def l2_norm(input, axis=1):
|
17 |
+
norm = torch.norm(input, 2, axis, True)
|
18 |
+
output = torch.div(input, norm)
|
19 |
+
return output
|
20 |
+
|
21 |
+
|
22 |
+
class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
|
23 |
+
""" A named tuple describing a ResNet block. """
|
24 |
+
|
25 |
+
|
26 |
+
def get_block(in_channel, depth, num_units, stride=2):
|
27 |
+
return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
|
28 |
+
|
29 |
+
|
30 |
+
def get_blocks(num_layers):
|
31 |
+
if num_layers == 50:
|
32 |
+
blocks = [
|
33 |
+
get_block(in_channel=64, depth=64, num_units=3),
|
34 |
+
get_block(in_channel=64, depth=128, num_units=4),
|
35 |
+
get_block(in_channel=128, depth=256, num_units=14),
|
36 |
+
get_block(in_channel=256, depth=512, num_units=3)
|
37 |
+
]
|
38 |
+
elif num_layers == 100:
|
39 |
+
blocks = [
|
40 |
+
get_block(in_channel=64, depth=64, num_units=3),
|
41 |
+
get_block(in_channel=64, depth=128, num_units=13),
|
42 |
+
get_block(in_channel=128, depth=256, num_units=30),
|
43 |
+
get_block(in_channel=256, depth=512, num_units=3)
|
44 |
+
]
|
45 |
+
elif num_layers == 152:
|
46 |
+
blocks = [
|
47 |
+
get_block(in_channel=64, depth=64, num_units=3),
|
48 |
+
get_block(in_channel=64, depth=128, num_units=8),
|
49 |
+
get_block(in_channel=128, depth=256, num_units=36),
|
50 |
+
get_block(in_channel=256, depth=512, num_units=3)
|
51 |
+
]
|
52 |
+
else:
|
53 |
+
raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
|
54 |
+
return blocks
|
55 |
+
|
56 |
+
|
57 |
+
class SEModule(Module):
|
58 |
+
def __init__(self, channels, reduction):
|
59 |
+
super(SEModule, self).__init__()
|
60 |
+
self.avg_pool = AdaptiveAvgPool2d(1)
|
61 |
+
self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
|
62 |
+
self.relu = ReLU(inplace=True)
|
63 |
+
self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
|
64 |
+
self.sigmoid = Sigmoid()
|
65 |
+
|
66 |
+
def forward(self, x):
|
67 |
+
module_input = x
|
68 |
+
x = self.avg_pool(x)
|
69 |
+
x = self.fc1(x)
|
70 |
+
x = self.relu(x)
|
71 |
+
x = self.fc2(x)
|
72 |
+
x = self.sigmoid(x)
|
73 |
+
return module_input * x
|
74 |
+
|
75 |
+
|
76 |
+
class bottleneck_IR(Module):
|
77 |
+
def __init__(self, in_channel, depth, stride):
|
78 |
+
super(bottleneck_IR, self).__init__()
|
79 |
+
if in_channel == depth:
|
80 |
+
self.shortcut_layer = MaxPool2d(1, stride)
|
81 |
+
else:
|
82 |
+
self.shortcut_layer = Sequential(
|
83 |
+
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
|
84 |
+
BatchNorm2d(depth)
|
85 |
+
)
|
86 |
+
self.res_layer = Sequential(
|
87 |
+
BatchNorm2d(in_channel),
|
88 |
+
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
|
89 |
+
Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
|
90 |
+
)
|
91 |
+
|
92 |
+
def forward(self, x):
|
93 |
+
shortcut = self.shortcut_layer(x)
|
94 |
+
res = self.res_layer(x)
|
95 |
+
return res + shortcut
|
96 |
+
|
97 |
+
|
98 |
+
class bottleneck_IR_SE(Module):
|
99 |
+
def __init__(self, in_channel, depth, stride):
|
100 |
+
super(bottleneck_IR_SE, self).__init__()
|
101 |
+
if in_channel == depth:
|
102 |
+
self.shortcut_layer = MaxPool2d(1, stride)
|
103 |
+
else:
|
104 |
+
self.shortcut_layer = Sequential(
|
105 |
+
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
|
106 |
+
BatchNorm2d(depth)
|
107 |
+
)
|
108 |
+
self.res_layer = Sequential(
|
109 |
+
BatchNorm2d(in_channel),
|
110 |
+
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
|
111 |
+
PReLU(depth),
|
112 |
+
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
|
113 |
+
BatchNorm2d(depth),
|
114 |
+
SEModule(depth, 16)
|
115 |
+
)
|
116 |
+
|
117 |
+
def forward(self, x):
|
118 |
+
shortcut = self.shortcut_layer(x)
|
119 |
+
res = self.res_layer(x)
|
120 |
+
return res + shortcut
|
121 |
+
|
122 |
+
|
123 |
+
def _upsample_add(x, y):
|
124 |
+
"""Upsample and add two feature maps.
|
125 |
+
Args:
|
126 |
+
x: (Variable) top feature map to be upsampled.
|
127 |
+
y: (Variable) lateral feature map.
|
128 |
+
Returns:
|
129 |
+
(Variable) added feature map.
|
130 |
+
Note in PyTorch, when input size is odd, the upsampled feature map
|
131 |
+
with `F.upsample(..., scale_factor=2, mode='nearest')`
|
132 |
+
maybe not equal to the lateral feature map size.
|
133 |
+
e.g.
|
134 |
+
original input size: [N,_,15,15] ->
|
135 |
+
conv2d feature map size: [N,_,8,8] ->
|
136 |
+
upsampled feature map size: [N,_,16,16]
|
137 |
+
So we choose bilinear upsample which supports arbitrary output sizes.
|
138 |
+
"""
|
139 |
+
_, _, H, W = y.size()
|
140 |
+
return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True) + y
|
encoder4editing/models/encoders/model_irse.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
|
2 |
+
from models.encoders.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
|
3 |
+
|
4 |
+
"""
|
5 |
+
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
|
6 |
+
"""
|
7 |
+
|
8 |
+
|
9 |
+
class Backbone(Module):
|
10 |
+
def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
|
11 |
+
super(Backbone, self).__init__()
|
12 |
+
assert input_size in [112, 224], "input_size should be 112 or 224"
|
13 |
+
assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
|
14 |
+
assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
|
15 |
+
blocks = get_blocks(num_layers)
|
16 |
+
if mode == 'ir':
|
17 |
+
unit_module = bottleneck_IR
|
18 |
+
elif mode == 'ir_se':
|
19 |
+
unit_module = bottleneck_IR_SE
|
20 |
+
self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
|
21 |
+
BatchNorm2d(64),
|
22 |
+
PReLU(64))
|
23 |
+
if input_size == 112:
|
24 |
+
self.output_layer = Sequential(BatchNorm2d(512),
|
25 |
+
Dropout(drop_ratio),
|
26 |
+
Flatten(),
|
27 |
+
Linear(512 * 7 * 7, 512),
|
28 |
+
BatchNorm1d(512, affine=affine))
|
29 |
+
else:
|
30 |
+
self.output_layer = Sequential(BatchNorm2d(512),
|
31 |
+
Dropout(drop_ratio),
|
32 |
+
Flatten(),
|
33 |
+
Linear(512 * 14 * 14, 512),
|
34 |
+
BatchNorm1d(512, affine=affine))
|
35 |
+
|
36 |
+
modules = []
|
37 |
+
for block in blocks:
|
38 |
+
for bottleneck in block:
|
39 |
+
modules.append(unit_module(bottleneck.in_channel,
|
40 |
+
bottleneck.depth,
|
41 |
+
bottleneck.stride))
|
42 |
+
self.body = Sequential(*modules)
|
43 |
+
|
44 |
+
def forward(self, x):
|
45 |
+
x = self.input_layer(x)
|
46 |
+
x = self.body(x)
|
47 |
+
x = self.output_layer(x)
|
48 |
+
return l2_norm(x)
|
49 |
+
|
50 |
+
|
51 |
+
def IR_50(input_size):
|
52 |
+
"""Constructs a ir-50 model."""
|
53 |
+
model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
|
54 |
+
return model
|
55 |
+
|
56 |
+
|
57 |
+
def IR_101(input_size):
|
58 |
+
"""Constructs a ir-101 model."""
|
59 |
+
model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
|
60 |
+
return model
|
61 |
+
|
62 |
+
|
63 |
+
def IR_152(input_size):
|
64 |
+
"""Constructs a ir-152 model."""
|
65 |
+
model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
|
66 |
+
return model
|
67 |
+
|
68 |
+
|
69 |
+
def IR_SE_50(input_size):
|
70 |
+
"""Constructs a ir_se-50 model."""
|
71 |
+
model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
|
72 |
+
return model
|
73 |
+
|
74 |
+
|
75 |
+
def IR_SE_101(input_size):
|
76 |
+
"""Constructs a ir_se-101 model."""
|
77 |
+
model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
|
78 |
+
return model
|
79 |
+
|
80 |
+
|
81 |
+
def IR_SE_152(input_size):
|
82 |
+
"""Constructs a ir_se-152 model."""
|
83 |
+
model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
|
84 |
+
return model
|
encoder4editing/models/encoders/psp_encoders.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn import Conv2d, BatchNorm2d, PReLU, Sequential, Module
|
7 |
+
|
8 |
+
from models.encoders.helpers import get_blocks, bottleneck_IR, bottleneck_IR_SE, _upsample_add
|
9 |
+
from models.stylegan2.model import EqualLinear
|
10 |
+
|
11 |
+
|
12 |
+
class ProgressiveStage(Enum):
|
13 |
+
WTraining = 0
|
14 |
+
Delta1Training = 1
|
15 |
+
Delta2Training = 2
|
16 |
+
Delta3Training = 3
|
17 |
+
Delta4Training = 4
|
18 |
+
Delta5Training = 5
|
19 |
+
Delta6Training = 6
|
20 |
+
Delta7Training = 7
|
21 |
+
Delta8Training = 8
|
22 |
+
Delta9Training = 9
|
23 |
+
Delta10Training = 10
|
24 |
+
Delta11Training = 11
|
25 |
+
Delta12Training = 12
|
26 |
+
Delta13Training = 13
|
27 |
+
Delta14Training = 14
|
28 |
+
Delta15Training = 15
|
29 |
+
Delta16Training = 16
|
30 |
+
Delta17Training = 17
|
31 |
+
Inference = 18
|
32 |
+
|
33 |
+
|
34 |
+
class GradualStyleBlock(Module):
|
35 |
+
def __init__(self, in_c, out_c, spatial):
|
36 |
+
super(GradualStyleBlock, self).__init__()
|
37 |
+
self.out_c = out_c
|
38 |
+
self.spatial = spatial
|
39 |
+
num_pools = int(np.log2(spatial))
|
40 |
+
modules = []
|
41 |
+
modules += [Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
|
42 |
+
nn.LeakyReLU()]
|
43 |
+
for i in range(num_pools - 1):
|
44 |
+
modules += [
|
45 |
+
Conv2d(out_c, out_c, kernel_size=3, stride=2, padding=1),
|
46 |
+
nn.LeakyReLU()
|
47 |
+
]
|
48 |
+
self.convs = nn.Sequential(*modules)
|
49 |
+
self.linear = EqualLinear(out_c, out_c, lr_mul=1)
|
50 |
+
|
51 |
+
def forward(self, x):
|
52 |
+
x = self.convs(x)
|
53 |
+
x = x.view(-1, self.out_c)
|
54 |
+
x = self.linear(x)
|
55 |
+
return x
|
56 |
+
|
57 |
+
|
58 |
+
class GradualStyleEncoder(Module):
|
59 |
+
def __init__(self, num_layers, mode='ir', opts=None):
|
60 |
+
super(GradualStyleEncoder, self).__init__()
|
61 |
+
assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
|
62 |
+
assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
|
63 |
+
blocks = get_blocks(num_layers)
|
64 |
+
if mode == 'ir':
|
65 |
+
unit_module = bottleneck_IR
|
66 |
+
elif mode == 'ir_se':
|
67 |
+
unit_module = bottleneck_IR_SE
|
68 |
+
self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
|
69 |
+
BatchNorm2d(64),
|
70 |
+
PReLU(64))
|
71 |
+
modules = []
|
72 |
+
for block in blocks:
|
73 |
+
for bottleneck in block:
|
74 |
+
modules.append(unit_module(bottleneck.in_channel,
|
75 |
+
bottleneck.depth,
|
76 |
+
bottleneck.stride))
|
77 |
+
self.body = Sequential(*modules)
|
78 |
+
|
79 |
+
self.styles = nn.ModuleList()
|
80 |
+
log_size = int(math.log(opts.stylegan_size, 2))
|
81 |
+
self.style_count = 2 * log_size - 2
|
82 |
+
self.coarse_ind = 3
|
83 |
+
self.middle_ind = 7
|
84 |
+
for i in range(self.style_count):
|
85 |
+
if i < self.coarse_ind:
|
86 |
+
style = GradualStyleBlock(512, 512, 16)
|
87 |
+
elif i < self.middle_ind:
|
88 |
+
style = GradualStyleBlock(512, 512, 32)
|
89 |
+
else:
|
90 |
+
style = GradualStyleBlock(512, 512, 64)
|
91 |
+
self.styles.append(style)
|
92 |
+
self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
|
93 |
+
self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
|
94 |
+
|
95 |
+
def forward(self, x):
|
96 |
+
x = self.input_layer(x)
|
97 |
+
|
98 |
+
latents = []
|
99 |
+
modulelist = list(self.body._modules.values())
|
100 |
+
for i, l in enumerate(modulelist):
|
101 |
+
x = l(x)
|
102 |
+
if i == 6:
|
103 |
+
c1 = x
|
104 |
+
elif i == 20:
|
105 |
+
c2 = x
|
106 |
+
elif i == 23:
|
107 |
+
c3 = x
|
108 |
+
|
109 |
+
for j in range(self.coarse_ind):
|
110 |
+
latents.append(self.styles[j](c3))
|
111 |
+
|
112 |
+
p2 = _upsample_add(c3, self.latlayer1(c2))
|
113 |
+
for j in range(self.coarse_ind, self.middle_ind):
|
114 |
+
latents.append(self.styles[j](p2))
|
115 |
+
|
116 |
+
p1 = _upsample_add(p2, self.latlayer2(c1))
|
117 |
+
for j in range(self.middle_ind, self.style_count):
|
118 |
+
latents.append(self.styles[j](p1))
|
119 |
+
|
120 |
+
out = torch.stack(latents, dim=1)
|
121 |
+
return out
|
122 |
+
|
123 |
+
|
124 |
+
class Encoder4Editing(Module):
|
125 |
+
def __init__(self, num_layers, mode='ir', opts=None):
|
126 |
+
super(Encoder4Editing, self).__init__()
|
127 |
+
assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
|
128 |
+
assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
|
129 |
+
blocks = get_blocks(num_layers)
|
130 |
+
if mode == 'ir':
|
131 |
+
unit_module = bottleneck_IR
|
132 |
+
elif mode == 'ir_se':
|
133 |
+
unit_module = bottleneck_IR_SE
|
134 |
+
self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
|
135 |
+
BatchNorm2d(64),
|
136 |
+
PReLU(64))
|
137 |
+
modules = []
|
138 |
+
for block in blocks:
|
139 |
+
for bottleneck in block:
|
140 |
+
modules.append(unit_module(bottleneck.in_channel,
|
141 |
+
bottleneck.depth,
|
142 |
+
bottleneck.stride))
|
143 |
+
self.body = Sequential(*modules)
|
144 |
+
|
145 |
+
self.styles = nn.ModuleList()
|
146 |
+
log_size = int(math.log(opts.stylegan_size, 2))
|
147 |
+
self.style_count = 2 * log_size - 2
|
148 |
+
self.coarse_ind = 3
|
149 |
+
self.middle_ind = 7
|
150 |
+
|
151 |
+
for i in range(self.style_count):
|
152 |
+
if i < self.coarse_ind:
|
153 |
+
style = GradualStyleBlock(512, 512, 16)
|
154 |
+
elif i < self.middle_ind:
|
155 |
+
style = GradualStyleBlock(512, 512, 32)
|
156 |
+
else:
|
157 |
+
style = GradualStyleBlock(512, 512, 64)
|
158 |
+
self.styles.append(style)
|
159 |
+
|
160 |
+
self.latlayer1 = nn.Conv2d(256, 512, kernel_size=1, stride=1, padding=0)
|
161 |
+
self.latlayer2 = nn.Conv2d(128, 512, kernel_size=1, stride=1, padding=0)
|
162 |
+
|
163 |
+
self.progressive_stage = ProgressiveStage.Inference
|
164 |
+
|
165 |
+
def get_deltas_starting_dimensions(self):
|
166 |
+
''' Get a list of the initial dimension of every delta from which it is applied '''
|
167 |
+
return list(range(self.style_count)) # Each dimension has a delta applied to it
|
168 |
+
|
169 |
+
def set_progressive_stage(self, new_stage: ProgressiveStage):
|
170 |
+
self.progressive_stage = new_stage
|
171 |
+
print('Changed progressive stage to: ', new_stage)
|
172 |
+
|
173 |
+
def forward(self, x):
|
174 |
+
x = self.input_layer(x)
|
175 |
+
|
176 |
+
modulelist = list(self.body._modules.values())
|
177 |
+
for i, l in enumerate(modulelist):
|
178 |
+
x = l(x)
|
179 |
+
if i == 6:
|
180 |
+
c1 = x
|
181 |
+
elif i == 20:
|
182 |
+
c2 = x
|
183 |
+
elif i == 23:
|
184 |
+
c3 = x
|
185 |
+
|
186 |
+
# Infer main W and duplicate it
|
187 |
+
w0 = self.styles[0](c3)
|
188 |
+
w = w0.repeat(self.style_count, 1, 1).permute(1, 0, 2)
|
189 |
+
stage = self.progressive_stage.value
|
190 |
+
features = c3
|
191 |
+
for i in range(1, min(stage + 1, self.style_count)): # Infer additional deltas
|
192 |
+
if i == self.coarse_ind:
|
193 |
+
p2 = _upsample_add(c3, self.latlayer1(c2)) # FPN's middle features
|
194 |
+
features = p2
|
195 |
+
elif i == self.middle_ind:
|
196 |
+
p1 = _upsample_add(p2, self.latlayer2(c1)) # FPN's fine features
|
197 |
+
features = p1
|
198 |
+
delta_i = self.styles[i](features)
|
199 |
+
w[:, i] += delta_i
|
200 |
+
return w
|
201 |
+
|
202 |
+
|
203 |
+
class BackboneEncoderUsingLastLayerIntoW(Module):
|
204 |
+
def __init__(self, num_layers, mode='ir', opts=None):
|
205 |
+
super(BackboneEncoderUsingLastLayerIntoW, self).__init__()
|
206 |
+
print('Using BackboneEncoderUsingLastLayerIntoW')
|
207 |
+
assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
|
208 |
+
assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
|
209 |
+
blocks = get_blocks(num_layers)
|
210 |
+
if mode == 'ir':
|
211 |
+
unit_module = bottleneck_IR
|
212 |
+
elif mode == 'ir_se':
|
213 |
+
unit_module = bottleneck_IR_SE
|
214 |
+
self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
|
215 |
+
BatchNorm2d(64),
|
216 |
+
PReLU(64))
|
217 |
+
self.output_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
|
218 |
+
self.linear = EqualLinear(512, 512, lr_mul=1)
|
219 |
+
modules = []
|
220 |
+
for block in blocks:
|
221 |
+
for bottleneck in block:
|
222 |
+
modules.append(unit_module(bottleneck.in_channel,
|
223 |
+
bottleneck.depth,
|
224 |
+
bottleneck.stride))
|
225 |
+
self.body = Sequential(*modules)
|
226 |
+
log_size = int(math.log(opts.stylegan_size, 2))
|
227 |
+
self.style_count = 2 * log_size - 2
|
228 |
+
|
229 |
+
def forward(self, x):
|
230 |
+
x = self.input_layer(x)
|
231 |
+
x = self.body(x)
|
232 |
+
x = self.output_pool(x)
|
233 |
+
x = x.view(-1, 512)
|
234 |
+
x = self.linear(x)
|
235 |
+
return x.repeat(self.style_count, 1, 1).permute(1, 0, 2)
|
encoder4editing/models/latent_codes_pool.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
class LatentCodesPool:
|
6 |
+
"""This class implements latent codes buffer that stores previously generated w latent codes.
|
7 |
+
This buffer enables us to update discriminators using a history of generated w's
|
8 |
+
rather than the ones produced by the latest encoder.
|
9 |
+
"""
|
10 |
+
|
11 |
+
def __init__(self, pool_size):
|
12 |
+
"""Initialize the ImagePool class
|
13 |
+
Parameters:
|
14 |
+
pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
|
15 |
+
"""
|
16 |
+
self.pool_size = pool_size
|
17 |
+
if self.pool_size > 0: # create an empty pool
|
18 |
+
self.num_ws = 0
|
19 |
+
self.ws = []
|
20 |
+
|
21 |
+
def query(self, ws):
|
22 |
+
"""Return w's from the pool.
|
23 |
+
Parameters:
|
24 |
+
ws: the latest generated w's from the generator
|
25 |
+
Returns w's from the buffer.
|
26 |
+
By 50/100, the buffer will return input w's.
|
27 |
+
By 50/100, the buffer will return w's previously stored in the buffer,
|
28 |
+
and insert the current w's to the buffer.
|
29 |
+
"""
|
30 |
+
if self.pool_size == 0: # if the buffer size is 0, do nothing
|
31 |
+
return ws
|
32 |
+
return_ws = []
|
33 |
+
for w in ws: # ws.shape: (batch, 512) or (batch, n_latent, 512)
|
34 |
+
# w = torch.unsqueeze(image.data, 0)
|
35 |
+
if w.ndim == 2:
|
36 |
+
i = random.randint(0, len(w) - 1) # apply a random latent index as a candidate
|
37 |
+
w = w[i]
|
38 |
+
self.handle_w(w, return_ws)
|
39 |
+
return_ws = torch.stack(return_ws, 0) # collect all the images and return
|
40 |
+
return return_ws
|
41 |
+
|
42 |
+
def handle_w(self, w, return_ws):
|
43 |
+
if self.num_ws < self.pool_size: # if the buffer is not full; keep inserting current codes to the buffer
|
44 |
+
self.num_ws = self.num_ws + 1
|
45 |
+
self.ws.append(w)
|
46 |
+
return_ws.append(w)
|
47 |
+
else:
|
48 |
+
p = random.uniform(0, 1)
|
49 |
+
if p > 0.5: # by 50% chance, the buffer will return a previously stored latent code, and insert the current code into the buffer
|
50 |
+
random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
|
51 |
+
tmp = self.ws[random_id].clone()
|
52 |
+
self.ws[random_id] = w
|
53 |
+
return_ws.append(tmp)
|
54 |
+
else: # by another 50% chance, the buffer will return the current image
|
55 |
+
return_ws.append(w)
|
encoder4editing/models/psp.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib
|
2 |
+
|
3 |
+
matplotlib.use('Agg')
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
from models.encoders import psp_encoders
|
7 |
+
from models.stylegan2.model import Generator
|
8 |
+
from configs.paths_config import model_paths
|
9 |
+
|
10 |
+
|
11 |
+
def get_keys(d, name):
|
12 |
+
if 'state_dict' in d:
|
13 |
+
d = d['state_dict']
|
14 |
+
d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
|
15 |
+
return d_filt
|
16 |
+
|
17 |
+
|
18 |
+
class pSp(nn.Module):
|
19 |
+
|
20 |
+
def __init__(self, opts):
|
21 |
+
super(pSp, self).__init__()
|
22 |
+
self.opts = opts
|
23 |
+
# Define architecture
|
24 |
+
self.encoder = self.set_encoder()
|
25 |
+
self.decoder = Generator(opts.stylegan_size, 512, 8, channel_multiplier=2)
|
26 |
+
self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
|
27 |
+
# Load weights if needed
|
28 |
+
self.load_weights()
|
29 |
+
|
30 |
+
def set_encoder(self):
|
31 |
+
if self.opts.encoder_type == 'GradualStyleEncoder':
|
32 |
+
encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts)
|
33 |
+
elif self.opts.encoder_type == 'Encoder4Editing':
|
34 |
+
encoder = psp_encoders.Encoder4Editing(50, 'ir_se', self.opts)
|
35 |
+
elif self.opts.encoder_type == 'SingleStyleCodeEncoder':
|
36 |
+
encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts)
|
37 |
+
else:
|
38 |
+
raise Exception('{} is not a valid encoders'.format(self.opts.encoder_type))
|
39 |
+
return encoder
|
40 |
+
|
41 |
+
def load_weights(self):
|
42 |
+
if self.opts.checkpoint_path is not None:
|
43 |
+
print('Loading e4e over the pSp framework from checkpoint: {}'.format(self.opts.checkpoint_path))
|
44 |
+
ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
|
45 |
+
self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
|
46 |
+
self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
|
47 |
+
self.__load_latent_avg(ckpt)
|
48 |
+
else:
|
49 |
+
print('Loading encoders weights from irse50!')
|
50 |
+
encoder_ckpt = torch.load(model_paths['ir_se50'])
|
51 |
+
self.encoder.load_state_dict(encoder_ckpt, strict=False)
|
52 |
+
print('Loading decoder weights from pretrained!')
|
53 |
+
ckpt = torch.load(self.opts.stylegan_weights)
|
54 |
+
self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
|
55 |
+
self.__load_latent_avg(ckpt, repeat=self.encoder.style_count)
|
56 |
+
|
57 |
+
def forward(self, x, resize=True, latent_mask=None, input_code=False, randomize_noise=True,
|
58 |
+
inject_latent=None, return_latents=False, alpha=None):
|
59 |
+
if input_code:
|
60 |
+
codes = x
|
61 |
+
else:
|
62 |
+
codes = self.encoder(x)
|
63 |
+
# normalize with respect to the center of an average face
|
64 |
+
if self.opts.start_from_latent_avg:
|
65 |
+
if codes.ndim == 2:
|
66 |
+
codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
|
67 |
+
else:
|
68 |
+
codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)
|
69 |
+
|
70 |
+
if latent_mask is not None:
|
71 |
+
for i in latent_mask:
|
72 |
+
if inject_latent is not None:
|
73 |
+
if alpha is not None:
|
74 |
+
codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
|
75 |
+
else:
|
76 |
+
codes[:, i] = inject_latent[:, i]
|
77 |
+
else:
|
78 |
+
codes[:, i] = 0
|
79 |
+
|
80 |
+
input_is_latent = not input_code
|
81 |
+
images, result_latent = self.decoder([codes],
|
82 |
+
input_is_latent=input_is_latent,
|
83 |
+
randomize_noise=randomize_noise,
|
84 |
+
return_latents=return_latents)
|
85 |
+
|
86 |
+
if resize:
|
87 |
+
images = self.face_pool(images)
|
88 |
+
|
89 |
+
if return_latents:
|
90 |
+
return images, result_latent
|
91 |
+
else:
|
92 |
+
return images
|
93 |
+
|
94 |
+
def __load_latent_avg(self, ckpt, repeat=None):
|
95 |
+
if 'latent_avg' in ckpt:
|
96 |
+
self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
|
97 |
+
if repeat is not None:
|
98 |
+
self.latent_avg = self.latent_avg.repeat(repeat, 1)
|
99 |
+
else:
|
100 |
+
self.latent_avg = None
|
encoder4editing/models/stylegan2/__init__.py
ADDED
File without changes
|
encoder4editing/models/stylegan2/model.py
ADDED
@@ -0,0 +1,673 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
from models.stylegan2.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
|
8 |
+
|
9 |
+
|
10 |
+
class PixelNorm(nn.Module):
|
11 |
+
def __init__(self):
|
12 |
+
super().__init__()
|
13 |
+
|
14 |
+
def forward(self, input):
|
15 |
+
return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
|
16 |
+
|
17 |
+
|
18 |
+
def make_kernel(k):
|
19 |
+
k = torch.tensor(k, dtype=torch.float32)
|
20 |
+
|
21 |
+
if k.ndim == 1:
|
22 |
+
k = k[None, :] * k[:, None]
|
23 |
+
|
24 |
+
k /= k.sum()
|
25 |
+
|
26 |
+
return k
|
27 |
+
|
28 |
+
|
29 |
+
class Upsample(nn.Module):
|
30 |
+
def __init__(self, kernel, factor=2):
|
31 |
+
super().__init__()
|
32 |
+
|
33 |
+
self.factor = factor
|
34 |
+
kernel = make_kernel(kernel) * (factor ** 2)
|
35 |
+
self.register_buffer('kernel', kernel)
|
36 |
+
|
37 |
+
p = kernel.shape[0] - factor
|
38 |
+
|
39 |
+
pad0 = (p + 1) // 2 + factor - 1
|
40 |
+
pad1 = p // 2
|
41 |
+
|
42 |
+
self.pad = (pad0, pad1)
|
43 |
+
|
44 |
+
def forward(self, input):
|
45 |
+
out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
|
46 |
+
|
47 |
+
return out
|
48 |
+
|
49 |
+
|
50 |
+
class Downsample(nn.Module):
|
51 |
+
def __init__(self, kernel, factor=2):
|
52 |
+
super().__init__()
|
53 |
+
|
54 |
+
self.factor = factor
|
55 |
+
kernel = make_kernel(kernel)
|
56 |
+
self.register_buffer('kernel', kernel)
|
57 |
+
|
58 |
+
p = kernel.shape[0] - factor
|
59 |
+
|
60 |
+
pad0 = (p + 1) // 2
|
61 |
+
pad1 = p // 2
|
62 |
+
|
63 |
+
self.pad = (pad0, pad1)
|
64 |
+
|
65 |
+
def forward(self, input):
|
66 |
+
out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
|
67 |
+
|
68 |
+
return out
|
69 |
+
|
70 |
+
|
71 |
+
class Blur(nn.Module):
|
72 |
+
def __init__(self, kernel, pad, upsample_factor=1):
|
73 |
+
super().__init__()
|
74 |
+
|
75 |
+
kernel = make_kernel(kernel)
|
76 |
+
|
77 |
+
if upsample_factor > 1:
|
78 |
+
kernel = kernel * (upsample_factor ** 2)
|
79 |
+
|
80 |
+
self.register_buffer('kernel', kernel)
|
81 |
+
|
82 |
+
self.pad = pad
|
83 |
+
|
84 |
+
def forward(self, input):
|
85 |
+
out = upfirdn2d(input, self.kernel, pad=self.pad)
|
86 |
+
|
87 |
+
return out
|
88 |
+
|
89 |
+
|
90 |
+
class EqualConv2d(nn.Module):
|
91 |
+
def __init__(
|
92 |
+
self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
|
93 |
+
):
|
94 |
+
super().__init__()
|
95 |
+
|
96 |
+
self.weight = nn.Parameter(
|
97 |
+
torch.randn(out_channel, in_channel, kernel_size, kernel_size)
|
98 |
+
)
|
99 |
+
self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
|
100 |
+
|
101 |
+
self.stride = stride
|
102 |
+
self.padding = padding
|
103 |
+
|
104 |
+
if bias:
|
105 |
+
self.bias = nn.Parameter(torch.zeros(out_channel))
|
106 |
+
|
107 |
+
else:
|
108 |
+
self.bias = None
|
109 |
+
|
110 |
+
def forward(self, input):
|
111 |
+
out = F.conv2d(
|
112 |
+
input,
|
113 |
+
self.weight * self.scale,
|
114 |
+
bias=self.bias,
|
115 |
+
stride=self.stride,
|
116 |
+
padding=self.padding,
|
117 |
+
)
|
118 |
+
|
119 |
+
return out
|
120 |
+
|
121 |
+
def __repr__(self):
|
122 |
+
return (
|
123 |
+
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
|
124 |
+
f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
|
125 |
+
)
|
126 |
+
|
127 |
+
|
128 |
+
class EqualLinear(nn.Module):
|
129 |
+
def __init__(
|
130 |
+
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
|
131 |
+
):
|
132 |
+
super().__init__()
|
133 |
+
|
134 |
+
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
|
135 |
+
|
136 |
+
if bias:
|
137 |
+
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
|
138 |
+
|
139 |
+
else:
|
140 |
+
self.bias = None
|
141 |
+
|
142 |
+
self.activation = activation
|
143 |
+
|
144 |
+
self.scale = (1 / math.sqrt(in_dim)) * lr_mul
|
145 |
+
self.lr_mul = lr_mul
|
146 |
+
|
147 |
+
def forward(self, input):
|
148 |
+
if self.activation:
|
149 |
+
out = F.linear(input, self.weight * self.scale)
|
150 |
+
out = fused_leaky_relu(out, self.bias * self.lr_mul)
|
151 |
+
|
152 |
+
else:
|
153 |
+
out = F.linear(
|
154 |
+
input, self.weight * self.scale, bias=self.bias * self.lr_mul
|
155 |
+
)
|
156 |
+
|
157 |
+
return out
|
158 |
+
|
159 |
+
def __repr__(self):
|
160 |
+
return (
|
161 |
+
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
|
162 |
+
)
|
163 |
+
|
164 |
+
|
165 |
+
class ScaledLeakyReLU(nn.Module):
|
166 |
+
def __init__(self, negative_slope=0.2):
|
167 |
+
super().__init__()
|
168 |
+
|
169 |
+
self.negative_slope = negative_slope
|
170 |
+
|
171 |
+
def forward(self, input):
|
172 |
+
out = F.leaky_relu(input, negative_slope=self.negative_slope)
|
173 |
+
|
174 |
+
return out * math.sqrt(2)
|
175 |
+
|
+
+class ModulatedConv2d(nn.Module):
+    def __init__(
+        self,
+        in_channel,
+        out_channel,
+        kernel_size,
+        style_dim,
+        demodulate=True,
+        upsample=False,
+        downsample=False,
+        blur_kernel=[1, 3, 3, 1],
+    ):
+        super().__init__()
+
+        self.eps = 1e-8
+        self.kernel_size = kernel_size
+        self.in_channel = in_channel
+        self.out_channel = out_channel
+        self.upsample = upsample
+        self.downsample = downsample
+
+        if upsample:
+            factor = 2
+            p = (len(blur_kernel) - factor) - (kernel_size - 1)
+            pad0 = (p + 1) // 2 + factor - 1
+            pad1 = p // 2 + 1
+
+            self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
+
+        if downsample:
+            factor = 2
+            p = (len(blur_kernel) - factor) + (kernel_size - 1)
+            pad0 = (p + 1) // 2
+            pad1 = p // 2
+
+            self.blur = Blur(blur_kernel, pad=(pad0, pad1))
+
+        fan_in = in_channel * kernel_size ** 2
+        self.scale = 1 / math.sqrt(fan_in)
+        self.padding = kernel_size // 2
+
+        self.weight = nn.Parameter(
+            torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
+        )
+
+        self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
+
+        self.demodulate = demodulate
+
+    def __repr__(self):
+        return (
+            f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
+            f'upsample={self.upsample}, downsample={self.downsample})'
+        )
+
+    def forward(self, input, style):
+        batch, in_channel, height, width = input.shape
+
+        style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
+        weight = self.scale * self.weight * style
+
+        if self.demodulate:
+            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
+            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
+
+        weight = weight.view(
+            batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
+        )
+
+        if self.upsample:
+            input = input.view(1, batch * in_channel, height, width)
+            weight = weight.view(
+                batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
+            )
+            weight = weight.transpose(1, 2).reshape(
+                batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
+            )
+            out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
+            _, _, height, width = out.shape
+            out = out.view(batch, self.out_channel, height, width)
+            out = self.blur(out)
+
+        elif self.downsample:
+            input = self.blur(input)
+            _, _, height, width = input.shape
+            input = input.view(1, batch * in_channel, height, width)
+            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
+            _, _, height, width = out.shape
+            out = out.view(batch, self.out_channel, height, width)
+
+        else:
+            input = input.view(1, batch * in_channel, height, width)
+            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
+            _, _, height, width = out.shape
+            out = out.view(batch, self.out_channel, height, width)
+
+        return out
+
+
+class NoiseInjection(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.weight = nn.Parameter(torch.zeros(1))
+
+    def forward(self, image, noise=None):
+        if noise is None:
+            batch, _, height, width = image.shape
+            noise = image.new_empty(batch, 1, height, width).normal_()
+
+        return image + self.weight * noise
+
+
+class ConstantInput(nn.Module):
+    def __init__(self, channel, size=4):
+        super().__init__()
+
+        self.input = nn.Parameter(torch.randn(1, channel, size, size))
+
+    def forward(self, input):
+        batch = input.shape[0]
+        out = self.input.repeat(batch, 1, 1, 1)
+
+        return out
+
+
+class StyledConv(nn.Module):
+    def __init__(
+        self,
+        in_channel,
+        out_channel,
+        kernel_size,
+        style_dim,
+        upsample=False,
+        blur_kernel=[1, 3, 3, 1],
+        demodulate=True,
+    ):
+        super().__init__()
+
+        self.conv = ModulatedConv2d(
+            in_channel,
+            out_channel,
+            kernel_size,
+            style_dim,
+            upsample=upsample,
+            blur_kernel=blur_kernel,
+            demodulate=demodulate,
+        )
+
+        self.noise = NoiseInjection()
+        # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
+        # self.activate = ScaledLeakyReLU(0.2)
+        self.activate = FusedLeakyReLU(out_channel)
+
+    def forward(self, input, style, noise=None):
+        out = self.conv(input, style)
+        out = self.noise(out, noise=noise)
+        # out = out + self.bias
+        out = self.activate(out)
+
+        return out
+
+
+class ToRGB(nn.Module):
+    def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
+        super().__init__()
+
+        if upsample:
+            self.upsample = Upsample(blur_kernel)
+
+        self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
+        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
+
+    def forward(self, input, style, skip=None):
+        out = self.conv(input, style)
+        out = out + self.bias
+
+        if skip is not None:
+            skip = self.upsample(skip)
+
+            out = out + skip
+
+        return out
+
+
+class Generator(nn.Module):
+    def __init__(
+        self,
+        size,
+        style_dim,
+        n_mlp,
+        channel_multiplier=2,
+        blur_kernel=[1, 3, 3, 1],
+        lr_mlp=0.01,
+    ):
+        super().__init__()
+
+        self.size = size
+
+        self.style_dim = style_dim
+
+        layers = [PixelNorm()]
+
+        for i in range(n_mlp):
+            layers.append(
+                EqualLinear(
+                    style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
+                )
+            )
+
+        self.style = nn.Sequential(*layers)
+
+        self.channels = {
+            4: 512,
+            8: 512,
+            16: 512,
+            32: 512,
+            64: 256 * channel_multiplier,
+            128: 128 * channel_multiplier,
+            256: 64 * channel_multiplier,
+            512: 32 * channel_multiplier,
+            1024: 16 * channel_multiplier,
+        }
+
+        self.input = ConstantInput(self.channels[4])
+        self.conv1 = StyledConv(
+            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
+        )
+        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
+
+        self.log_size = int(math.log(size, 2))
+        self.num_layers = (self.log_size - 2) * 2 + 1
+
+        self.convs = nn.ModuleList()
+        self.upsamples = nn.ModuleList()
+        self.to_rgbs = nn.ModuleList()
+        self.noises = nn.Module()
+
+        in_channel = self.channels[4]
+
+        for layer_idx in range(self.num_layers):
+            res = (layer_idx + 5) // 2
+            shape = [1, 1, 2 ** res, 2 ** res]
+            self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
+
+        for i in range(3, self.log_size + 1):
+            out_channel = self.channels[2 ** i]
+
+            self.convs.append(
+                StyledConv(
+                    in_channel,
+                    out_channel,
+                    3,
+                    style_dim,
+                    upsample=True,
+                    blur_kernel=blur_kernel,
+                )
+            )
+
+            self.convs.append(
+                StyledConv(
+                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
+                )
+            )
+
+            self.to_rgbs.append(ToRGB(out_channel, style_dim))
+
+            in_channel = out_channel
+
+        self.n_latent = self.log_size * 2 - 2
+
+    def make_noise(self):
+        device = self.input.input.device
+
+        noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
+
+        for i in range(3, self.log_size + 1):
+            for _ in range(2):
+                noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
+
+        return noises
+
+    def mean_latent(self, n_latent):
+        latent_in = torch.randn(
+            n_latent, self.style_dim, device=self.input.input.device
+        )
+        latent = self.style(latent_in).mean(0, keepdim=True)
+
+        return latent
+
+    def get_latent(self, input):
+        return self.style(input)
+
+    def forward(
+        self,
+        styles,
+        return_latents=False,
+        return_features=False,
+        inject_index=None,
+        truncation=1,
+        truncation_latent=None,
+        input_is_latent=False,
+        noise=None,
+        randomize_noise=True,
+    ):
+        if not input_is_latent:
+            styles = [self.style(s) for s in styles]
+
+        if noise is None:
+            if randomize_noise:
+                noise = [None] * self.num_layers
+            else:
+                noise = [
+                    getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
+                ]
+
+        if truncation < 1:
+            style_t = []
+
+            for style in styles:
+                style_t.append(
+                    truncation_latent + truncation * (style - truncation_latent)
+                )
+
+            styles = style_t
+
+        if len(styles) < 2:
+            inject_index = self.n_latent
+
+            if styles[0].ndim < 3:
+                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+            else:
+                latent = styles[0]
+
+        else:
+            if inject_index is None:
+                inject_index = random.randint(1, self.n_latent - 1)
+
+            latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
+            latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
+
+            latent = torch.cat([latent, latent2], 1)
+
+        out = self.input(latent)
+        out = self.conv1(out, latent[:, 0], noise=noise[0])
+
+        skip = self.to_rgb1(out, latent[:, 1])
+
+        i = 1
+        for conv1, conv2, noise1, noise2, to_rgb in zip(
+            self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
+        ):
+            out = conv1(out, latent[:, i], noise=noise1)
+            out = conv2(out, latent[:, i + 1], noise=noise2)
+            skip = to_rgb(out, latent[:, i + 2], skip)
+
+            i += 2
+
+        image = skip
+
+        if return_latents:
+            return image, latent
+        elif return_features:
+            return image, out
+        else:
+            return image, None
+
+
+class ConvLayer(nn.Sequential):
+    def __init__(
+        self,
+        in_channel,
+        out_channel,
+        kernel_size,
+        downsample=False,
+        blur_kernel=[1, 3, 3, 1],
+        bias=True,
+        activate=True,
+    ):
+        layers = []
+
+        if downsample:
+            factor = 2
+            p = (len(blur_kernel) - factor) + (kernel_size - 1)
+            pad0 = (p + 1) // 2
+            pad1 = p // 2
+
+            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
+
+            stride = 2
+            self.padding = 0
+
+        else:
+            stride = 1
+            self.padding = kernel_size // 2
+
+        layers.append(
+            EqualConv2d(
+                in_channel,
+                out_channel,
+                kernel_size,
+                padding=self.padding,
+                stride=stride,
+                bias=bias and not activate,
+            )
+        )
+
+        if activate:
+            if bias:
+                layers.append(FusedLeakyReLU(out_channel))
+
+            else:
+                layers.append(ScaledLeakyReLU(0.2))
+
+        super().__init__(*layers)
+
+
+class ResBlock(nn.Module):
+    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
+        super().__init__()
+
+        self.conv1 = ConvLayer(in_channel, in_channel, 3)
+        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
+
+        self.skip = ConvLayer(
+            in_channel, out_channel, 1, downsample=True, activate=False, bias=False
+        )
+
+    def forward(self, input):
+        out = self.conv1(input)
+        out = self.conv2(out)
+
+        skip = self.skip(input)
+        out = (out + skip) / math.sqrt(2)
+
+        return out
+
+
+class Discriminator(nn.Module):
+    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
+        super().__init__()
+
+        channels = {
+            4: 512,
+            8: 512,
+            16: 512,
+            32: 512,
+            64: 256 * channel_multiplier,
+            128: 128 * channel_multiplier,
+            256: 64 * channel_multiplier,
+            512: 32 * channel_multiplier,
+            1024: 16 * channel_multiplier,
+        }
+
+        convs = [ConvLayer(3, channels[size], 1)]
+
+        log_size = int(math.log(size, 2))
+
+        in_channel = channels[size]
+
+        for i in range(log_size, 2, -1):
+            out_channel = channels[2 ** (i - 1)]
+
+            convs.append(ResBlock(in_channel, out_channel, blur_kernel))
+
+            in_channel = out_channel
+
+        self.convs = nn.Sequential(*convs)
+
+        self.stddev_group = 4
+        self.stddev_feat = 1
+
+        self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
+        self.final_linear = nn.Sequential(
+            EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
+            EqualLinear(channels[4], 1),
+        )
+
+    def forward(self, input):
+        out = self.convs(input)
+
+        batch, channel, height, width = out.shape
+        group = min(batch, self.stddev_group)
+        stddev = out.view(
+            group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
+        )
+        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
+        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
+        stddev = stddev.repeat(group, 1, height, width)
+        out = torch.cat([out, stddev], 1)
+
+        out = self.final_conv(out)
+
+        out = out.view(batch, -1)
+        out = self.final_linear(out)
+
+        return out
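
Note (not part of the diff): a minimal sketch of how the Generator added above is typically driven, assuming a 1024px FFHQ-style configuration (size=1024, style_dim=512, n_mlp=8) and a CUDA device; all values below are illustrative assumptions, not taken from this repo.

```python
# Hedged usage sketch for the Generator class in model.py (assumed FFHQ settings).
import torch

g = Generator(size=1024, style_dim=512, n_mlp=8).cuda().eval()

with torch.no_grad():
    z = torch.randn(4, 512, device='cuda')       # latent codes in Z
    mean_w = g.mean_latent(4096)                  # average W latent, used for truncation
    images, latents = g(
        [z],
        truncation=0.7,
        truncation_latent=mean_w,
        return_latents=True,
    )

# images: (4, 3, 1024, 1024), roughly in [-1, 1]; latents: (4, 18, 512) W+ codes,
# since n_latent = log2(1024) * 2 - 2 = 18 for this configuration.
```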
encoder4editing/models/stylegan2/op/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .fused_act import FusedLeakyReLU, fused_leaky_relu
+from .upfirdn2d import upfirdn2d
encoder4editing/models/stylegan2/op/fused_act.py
ADDED
@@ -0,0 +1,85 @@
+import os
+
+import torch
+from torch import nn
+from torch.autograd import Function
+from torch.utils.cpp_extension import load
+
+module_path = os.path.dirname(__file__)
+fused = load(
+    'fused',
+    sources=[
+        os.path.join(module_path, 'fused_bias_act.cpp'),
+        os.path.join(module_path, 'fused_bias_act_kernel.cu'),
+    ],
+)
+
+
+class FusedLeakyReLUFunctionBackward(Function):
+    @staticmethod
+    def forward(ctx, grad_output, out, negative_slope, scale):
+        ctx.save_for_backward(out)
+        ctx.negative_slope = negative_slope
+        ctx.scale = scale
+
+        empty = grad_output.new_empty(0)
+
+        grad_input = fused.fused_bias_act(
+            grad_output, empty, out, 3, 1, negative_slope, scale
+        )
+
+        dim = [0]
+
+        if grad_input.ndim > 2:
+            dim += list(range(2, grad_input.ndim))
+
+        grad_bias = grad_input.sum(dim).detach()
+
+        return grad_input, grad_bias
+
+    @staticmethod
+    def backward(ctx, gradgrad_input, gradgrad_bias):
+        out, = ctx.saved_tensors
+        gradgrad_out = fused.fused_bias_act(
+            gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
+        )
+
+        return gradgrad_out, None, None, None
+
+
+class FusedLeakyReLUFunction(Function):
+    @staticmethod
+    def forward(ctx, input, bias, negative_slope, scale):
+        empty = input.new_empty(0)
+        out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
+        ctx.save_for_backward(out)
+        ctx.negative_slope = negative_slope
+        ctx.scale = scale
+
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        out, = ctx.saved_tensors
+
+        grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
+            grad_output, out, ctx.negative_slope, ctx.scale
+        )
+
+        return grad_input, grad_bias, None, None
+
+
+class FusedLeakyReLU(nn.Module):
+    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
+        super().__init__()
+
+        self.bias = nn.Parameter(torch.zeros(channel))
+        self.negative_slope = negative_slope
+        self.scale = scale
+
+    def forward(self, input):
+        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
+    return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
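
Note (not part of the diff): for readers without a CUDA build, the forward pass of the fused op above is numerically equivalent to adding a per-channel bias, applying leaky ReLU, and rescaling by `scale` (here √2). A hedged pure-PyTorch reference of that computation:

```python
# Hedged reference implementation (assumption, not part of the diff): what
# fused_bias_act computes for act=3 (leaky ReLU) in the forward pass.
import torch
import torch.nn.functional as F

def fused_leaky_relu_reference(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    # Broadcast the per-channel bias over every dimension after the channel dim.
    rest_dim = [1] * (input.ndim - 2)
    biased = input + bias.view(1, bias.shape[0], *rest_dim)
    return F.leaky_relu(biased, negative_slope) * scale
```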
encoder4editing/models/stylegan2/op/fused_bias_act.cpp
ADDED
@@ -0,0 +1,21 @@
+#include <torch/extension.h>
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+    int act, int grad, float alpha, float scale);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+    int act, int grad, float alpha, float scale) {
+    CHECK_CUDA(input);
+    CHECK_CUDA(bias);
+
+    return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
+}
+
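Note (not part of the diff): `fused_act.py` above compiles this binding at import time via `torch.utils.cpp_extension.load`, which requires `nvcc` on the machine. A hedged alternative, if one preferred building the extension ahead of time, would be a small setuptools script; the file name `setup.py` and module name `'fused'` below are illustrative assumptions.

```python
# Hedged sketch: ahead-of-time build of the fused op instead of the JIT load() call.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='fused',
    ext_modules=[
        CUDAExtension(
            'fused',
            ['fused_bias_act.cpp', 'fused_bias_act_kernel.cu'],
        )
    ],
    cmdclass={'build_ext': BuildExtension},
)
```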
encoder4editing/models/stylegan2/op/fused_bias_act_kernel.cu
ADDED
@@ -0,0 +1,99 @@
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/types.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
+template <typename scalar_t>
+static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
+    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
+    int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
+
+    scalar_t zero = 0.0;
+
+    for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
+        scalar_t x = p_x[xi];
+
+        if (use_bias) {
+            x += p_b[(xi / step_b) % size_b];
+        }
+
+        scalar_t ref = use_ref ? p_ref[xi] : zero;
+
+        scalar_t y;
+
+        switch (act * 10 + grad) {
+            default:
+            case 10: y = x; break;
+            case 11: y = x; break;
+            case 12: y = 0.0; break;
+
+            case 30: y = (x > 0.0) ? x : x * alpha; break;
+            case 31: y = (ref > 0.0) ? x : x * alpha; break;
+            case 32: y = 0.0; break;
+        }
+
+        out[xi] = y * scale;
+    }
+}
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+    int act, int grad, float alpha, float scale) {
+    int curDevice = -1;
+    cudaGetDevice(&curDevice);
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+    auto x = input.contiguous();
+    auto b = bias.contiguous();
+    auto ref = refer.contiguous();
+
+    int use_bias = b.numel() ? 1 : 0;
+    int use_ref = ref.numel() ? 1 : 0;
+
+    int size_x = x.numel();
+    int size_b = b.numel();
+    int step_b = 1;
+
+    for (int i = 1 + 1; i < x.dim(); i++) {
+        step_b *= x.size(i);
+    }
+
+    int loop_x = 4;
+    int block_size = 4 * 32;
+    int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
+
+    auto y = torch::empty_like(x);
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
+        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
+            y.data_ptr<scalar_t>(),
+            x.data_ptr<scalar_t>(),
+            b.data_ptr<scalar_t>(),
+            ref.data_ptr<scalar_t>(),
+            act,
+            grad,
+            alpha,
+            scale,
+            loop_x,
+            size_x,
+            step_b,
+            size_b,
+            use_bias,
+            use_ref
+        );
+    });
+
+    return y;
+}
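
Note (not part of the diff): in the kernel above, `act * 10 + grad` selects the nonlinearity and its derivative order: `act=1` is the identity, `act=3` is leaky ReLU, and `grad=0/1/2` pick the function, its first derivative, and its second derivative (which is zero almost everywhere). A hedged scalar sketch of the same dispatch, for reference only:

```python
# Hedged illustration of the act/grad dispatch; not the CUDA code path itself.
def bias_act_scalar(x, ref, act, grad, alpha, scale):
    key = act * 10 + grad
    if key in (10, 11):                   # linear: value and first derivative
        y = x
    elif key == 12:                       # linear: second derivative
        y = 0.0
    elif key == 30:                       # leaky ReLU forward
        y = x if x > 0.0 else x * alpha
    elif key == 31:                       # leaky ReLU backward: gate by the saved output
        y = x if ref > 0.0 else x * alpha
    elif key == 32:                       # leaky ReLU second derivative
        y = 0.0
    else:
        y = x
    return y * scale
```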
encoder4editing/models/stylegan2/op/upfirdn2d.cpp
ADDED
@@ -0,0 +1,23 @@
+#include <torch/extension.h>
+
+
+torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
+    int up_x, int up_y, int down_x, int down_y,
+    int pad_x0, int pad_x1, int pad_y0, int pad_y1);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
+    int up_x, int up_y, int down_x, int down_y,
+    int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
+    CHECK_CUDA(input);
+    CHECK_CUDA(kernel);
+
+    return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
+}
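
Note (not part of the diff): `upfirdn2d` is the upsample → FIR filter → downsample primitive behind the `Blur` and `Upsample` modules in model.py. A hedged usage sketch, assuming the Python wrapper `upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0))` from the companion `upfirdn2d.py` (not shown in this part of the diff), the package being importable as below, and a CUDA device:

```python
# Hedged sketch: size-preserving blur of a feature map with the [1, 3, 3, 1] kernel.
import torch
from encoder4editing.models.stylegan2.op import upfirdn2d  # assumed import path

k = torch.tensor([1.0, 3.0, 3.0, 1.0])
kernel = k[None, :] * k[:, None]          # separable 4x4 blur kernel
kernel = kernel / kernel.sum()

x = torch.randn(1, 512, 64, 64, device='cuda')
blurred = upfirdn2d(x, kernel.cuda(), up=1, down=1, pad=(2, 1))
# pad=(2, 1) with a length-4 kernel keeps the 64x64 spatial size unchanged.
```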