Spaces: Running on Zero

Commit: Add reactor

This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
- .gitignore +0 -1
- app.py +2 -2
- custom_nodes/ComfyUI-ReActor/.gitignore +5 -0
- custom_nodes/ComfyUI-ReActor/LICENSE +674 -0
- custom_nodes/ComfyUI-ReActor/README.md +488 -0
- custom_nodes/ComfyUI-ReActor/README_RU.md +497 -0
- custom_nodes/ComfyUI-ReActor/__init__.py +39 -0
- custom_nodes/ComfyUI-ReActor/install.bat +37 -0
- custom_nodes/ComfyUI-ReActor/install.py +104 -0
- custom_nodes/ComfyUI-ReActor/modules/__init__.py +0 -0
- custom_nodes/ComfyUI-ReActor/modules/images.py +0 -0
- custom_nodes/ComfyUI-ReActor/modules/processing.py +13 -0
- custom_nodes/ComfyUI-ReActor/modules/scripts.py +13 -0
- custom_nodes/ComfyUI-ReActor/modules/scripts_postprocessing.py +0 -0
- custom_nodes/ComfyUI-ReActor/modules/shared.py +19 -0
- custom_nodes/ComfyUI-ReActor/nodes.py +1364 -0
- custom_nodes/ComfyUI-ReActor/pyproject.toml +15 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/__init__.py +12 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/archs/__init__.py +25 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/archs/arch_util.py +322 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/archs/basicvsr_arch.py +336 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/archs/basicvsrpp_arch.py +407 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/archs/dfdnet_arch.py +169 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/archs/dfdnet_util.py +162 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/archs/discriminator_arch.py +150 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/archs/duf_arch.py +277 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/archs/ecbsr_arch.py +274 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/archs/edsr_arch.py +61 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/archs/edvr_arch.py +383 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/archs/hifacegan_arch.py +259 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/archs/hifacegan_util.py +255 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/archs/inception.py +307 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/archs/rcan_arch.py +135 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/archs/ridnet_arch.py +184 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/archs/rrdbnet_arch.py +119 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/archs/spynet_arch.py +96 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/archs/srresnet_arch.py +65 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/archs/srvgg_arch.py +70 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/archs/stylegan2_arch.py +799 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/archs/swinir_arch.py +956 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/archs/tof_arch.py +172 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/archs/vgg_arch.py +161 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/data/__init__.py +101 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/data/data_sampler.py +48 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/data/data_util.py +313 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/data/degradations.py +768 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/data/ffhq_dataset.py +80 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/data/paired_image_dataset.py +108 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/data/prefetch_dataloader.py +125 -0
- custom_nodes/ComfyUI-ReActor/r_basicsr/data/realesrgan_dataset.py +193 -0
.gitignore CHANGED

@@ -5,7 +5,6 @@ __pycache__/
 !/input/example.png
 /models/
 /temp/
-/custom_nodes/
 !custom_nodes/example_node.py.example
 extra_model_paths.yaml
 /.vs

Removing `/custom_nodes/` from the ignore list is what allows the bundled ComfyUI-ReActor node below to be committed to the Space repository.
app.py CHANGED

@@ -6,10 +6,10 @@ import gradio as gr
 import torch
 from huggingface_hub import hf_hub_download
 from nodes import NODE_CLASS_MAPPINGS
-import spaces
 from comfy import model_management
 
-
+# import spaces
+# @spaces.GPU(duration=60)  # set the duration to the average time, in seconds, your workflow takes to run
 
 
 def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
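The commented-out decorator is the usual ZeroGPU pattern; re-enabled, it would look roughly like the sketch below (`run_workflow` is a hypothetical stand-in for this app's real entry point, not code from this repo):

```python
# Minimal sketch of the ZeroGPU pattern the diff comments out.
import gradio as gr
import spaces  # available on Hugging Face ZeroGPU Spaces

@spaces.GPU(duration=60)  # duration ~ average seconds the workflow needs a GPU
def run_workflow(prompt: str) -> str:
    # ...execute the ComfyUI graph on the GPU here...
    return prompt

demo = gr.Interface(fn=run_workflow, inputs="text", outputs="text")

if __name__ == "__main__":
    demo.launch()
```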
custom_nodes/ComfyUI-ReActor/.gitignore ADDED (+5 lines)

__pycache__/
*$py.class
.vscode/
example
input
custom_nodes/ComfyUI-ReActor/LICENSE ADDED (+674 lines)
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007

Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.

[Lines 7-674: the verbatim, unmodified text of the GNU GPL v3 - Preamble, Terms and Conditions (sections 0-17), and "How to Apply These Terms to Your New Programs" - ending with the pointer to <https://www.gnu.org/licenses/why-not-lgpl.html>.]
custom_nodes/ComfyUI-ReActor/README.md ADDED (+488 lines)
<div align="center">

<img src="https://github.com/Gourieff/Assets/raw/main/sd-webui-reactor/ReActor_logo_NEW_EN.png?raw=true" alt="logo" width="180px"/>

<!--<sup>
<font color=brightred>

## !!! [Important Update](#latestupdate) !!!<br>Don't forget to add the Node again in existing workflows

</font>
</sup>-->

<a href="https://boosty.to/artgourieff" target="_blank">
<img src="https://lovemet.ru/img/boosty.jpg" width="108" alt="Support Me on Boosty"/>
<br>
<sup>
Support This Project
</sup>
</a>

<hr>

[](https://github.com/Gourieff/ComfyUI-ReActor/commits/main)
[](https://github.com/Gourieff/ComfyUI-ReActor/issues?cacheSeconds=0)
[](https://github.com/Gourieff/ComfyUI-ReActor/issues?q=is%3Aissue+state%3Aclosed)

English | [Русский](/README_RU.md)

# ReActor Nodes for ComfyUI<br><sub><sup>-=SFW-Friendly=-</sup></sub>

</div>

### The Fast and Simple Face Swap Extension Nodes for ComfyUI, based on the [blocked ReActor](https://github.com/Gourieff/comfyui-reactor-node) - it now includes a nudity detector to prevent use of this software with 18+ content

> By using this Node you accept and assume [responsibility](#disclaimer)

<div align="center">

---
[**What's new**](#latestupdate) | [**Installation**](#installation) | [**Usage**](#usage) | [**Troubleshooting**](#troubleshooting) | [**Updating**](#updating) | [**Disclaimer**](#disclaimer) | [**Credits**](#credits) | [**Note!**](#note)

---

</div>

<a name="latestupdate">

## What's new in the latest update

### 0.6.0 <sub><sup>ALPHA1</sup></sub>

- New Node `ReActorSetWeight` - you can now set the strength of the face swap for `source_image` or `face_model` from 0% to 100% (in 12.5% steps); see the arithmetic check after the images below

<center>
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.6.0-whatsnew-01.jpg?raw=true" alt="0.6.0-whatsnew-01" width="100%"/>
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.6.0-whatsnew-02.jpg?raw=true" alt="0.6.0-whatsnew-02" width="100%"/>
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.6.0-alpha1-01.gif?raw=true" alt="0.6.0-whatsnew-03" width="540px"/>
</center>

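Since the strength moves in 12.5% increments, there are exactly nine valid values; a quick arithmetic check (illustration only, not code from the node):

```python
# The 12.5% step yields nine discrete face-swap strengths.
weights = [i * 12.5 for i in range(9)]
print(weights)  # [0.0, 12.5, 25.0, 37.5, 50.0, 62.5, 75.0, 87.5, 100.0]
assert len(weights) == 9 and weights[-1] == 100.0
```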
<details>
<summary><a>Previous versions</a></summary>

### 0.5.2

- ReSwapper models support. Inswapper still has the best similarity, but ReSwapper is evolving - thanks to @somanchiu (https://github.com/somanchiu/ReSwapper) for the ReSwapper models and the ReSwapper project! This is a good step for the community toward creating an alternative to Inswapper!

<center>
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.2-whatsnew-03.jpg?raw=true" alt="0.5.2-whatsnew-03" width="75%"/>
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.2-whatsnew-04.jpg?raw=true" alt="0.5.2-whatsnew-04" width="75%"/>
</center>

You can download ReSwapper models here:
https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models
Just put them into the "models/reswapper" directory.

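Since the files are hosted in a regular Hugging Face dataset, the download can also be scripted; a hedged sketch with `hf_hub_download` (the ReSwapper filename below is an assumption - check the dataset page for the actual names):

```python
# Fetch a ReSwapper model from the Gourieff/ReActor HF dataset repo.
from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="Gourieff/ReActor",
    repo_type="dataset",                   # the models live in a dataset repo
    filename="models/reswapper_256.onnx",  # hypothetical filename
    local_dir="models/reswapper",          # target dir per the README
)
print(path)  # note: hf_hub_download keeps the repo subpath under local_dir
```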
- An NSFW detector, so as not to violate the [GitHub rules](https://docs.github.com/en/site-policy/acceptable-use-policies/github-misinformation-and-disinformation#synthetic--manipulated-media-tools)
- New node "Unload ReActor Models" - useful for complex workflows when you need to free VRAM used by ReActor; see the sketch below

<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.2-whatsnew-01.jpg?raw=true" alt="0.5.2-whatsnew-01" width="100%"/>

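For context, a rough plain-Python equivalent of what such an unload step does with ComfyUI's `model_management` module (an assumption about intent, not the node's implementation):

```python
# Ask ComfyUI to drop loaded models and release cached GPU memory.
from comfy import model_management

model_management.unload_all_models()  # unload checkpoints/detectors from VRAM
model_management.soft_empty_cache()   # then free the CUDA allocator cache
```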
- Support for the ORT CoreML and ROCm EPs; just install the onnxruntime version you need
- Install-script improvements to install the latest versions of ORT-GPU

<center>
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.2-whatsnew-02.jpg?raw=true" alt="0.5.2-whatsnew-02" width="50%"/>
</center>

- Fixes and improvements


### 0.5.1

- Support for the GPEN 1024/2048 restoration models (available in the HF dataset https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models/facerestore_models)
- ReActorFaceBoost Node - an attempt to improve the quality of swapped faces. The idea is to restore and scale the swapped face (according to the `face_size` parameter of the restoration model) BEFORE pasting it into the target image (via inswapper algorithms); more information is [here (PR#321)](https://github.com/Gourieff/comfyui-reactor-node/pull/321), and a conceptual sketch follows below

<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.1-whatsnew-01.jpg?raw=true" alt="0.5.1-whatsnew-01" width="100%"/>

[Full size demo preview](https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.1-whatsnew-02.png)

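A conceptual sketch of that order of operations (restore first, then scale and paste back) - this is not the node's actual code, and `cv2.detailEnhance` is only a stand-in for a real restorer such as GPEN:

```python
# Illustrative only: restore the swapped crop BEFORE resizing/pasting it back.
import cv2
import numpy as np

def boost_and_paste(swapped_crop: np.ndarray, target: np.ndarray,
                    box: tuple, face_size: int) -> np.ndarray:
    # inputs are BGR uint8 images; box = (x, y, w, h) in target coordinates
    restored = cv2.detailEnhance(swapped_crop, sigma_s=10, sigma_r=0.15)  # stand-in restorer
    restored = cv2.resize(restored, (face_size, face_size))  # restorer's working size
    x, y, w, h = box
    target[y:y + h, x:x + w] = cv2.resize(restored, (w, h))  # paste into target
    return target
```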
104 |
+
- Sorting facemodels alphabetically
|
105 |
+
- A lot of fixes and improvements
|
106 |
+
|
107 |
+
### [0.5.0 <sub><sup>BETA4</sup></sub>](https://github.com/Gourieff/comfyui-reactor-node/releases/tag/v0.5.0)
|
108 |
+
|
109 |
+
- Spandrel lib support for GFPGAN
|
110 |
+
|
111 |
+
### 0.5.0 <sub><sup>BETA3</sup></sub>
|
112 |
+
|
113 |
+
- Fixes: "RAM issue", "No detection" for MaskingHelper
|
114 |
+
|
115 |
+
### 0.5.0 <sub><sup>BETA2</sup></sub>
|
116 |
+
|
117 |
+
- You can now build a blended face model from a batch of face models you already have, just add the "Make Face Model Batch" node to your workflow and connect several models via "Load Face Model"
|
118 |
+
- Huge performance boost of the image analyzer's module! 10x speed up! Working with videos is now a pleasure!
|
119 |
+
|
120 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.0-whatsnew-05.png?raw=true" alt="0.5.0-whatsnew-05" width="100%"/>
|
121 |
+
|
122 |
+
### 0.5.0 <sub><sup>BETA1</sup></sub>
|
123 |
+
|
124 |
+
- SWAPPED_FACE output for the Masking Helper Node
|
125 |
+
- FIX: Empty A-channel for Masking Helper IMAGE output (causing errors with some nodes) was removed
|
126 |
+
|
127 |
+
### 0.5.0 <sub><sup>ALPHA1</sup></sub>
|
128 |
+
|
129 |
+
- ReActorBuildFaceModel Node got a "face_model" output to provide a blended face model directly to the main Node:
|
130 |
+
|
131 |
+
Basic workflow [💾](https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/workflows/ReActor--Build-Blended-Face-Model--v2.json)
|
132 |
+
|
133 |
+
- Face Masking feature is available now, just add the "ReActorMaskHelper" Node to the workflow and connect it as shown below:
|
134 |
+
|
135 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.0-whatsnew-01.jpg?raw=true" alt="0.5.0-whatsnew-01" width="100%"/>
|
136 |
+
|
137 |
+
If you don't have the "face_yolov8m.pt" Ultralytics model - you can download it from the [Assets](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/detection/bbox/face_yolov8m.pt) and put it into the "ComfyUI\models\ultralytics\bbox" directory
|
138 |
+
<br>
|
139 |
+
As well as ["sam_vit_b_01ec64.pth"](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/sams/sam_vit_b_01ec64.pth) model - download (if you don't have it) and put it into the "ComfyUI\models\sams" directory;
|
140 |
+
|
141 |
+
Use this Node to get the best results from the face-swapping process:
|
142 |
+
|
143 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.0-whatsnew-02.jpg?raw=true" alt="0.5.0-whatsnew-02" width="100%"/>
|
144 |
+
|
145 |
+
- ReActorImageDublicator Node - rather useful for those who create videos: it duplicates one image into several frames to use them with the VAE Encoder (e.g. live avatars) - see the conceptual sketch below the screenshot:
|
146 |
+
|
147 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.0-whatsnew-03.jpg?raw=true" alt="0.5.0-whatsnew-03" width="100%"/>
|
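For intuition only (this is not the node's source code): in ComfyUI an IMAGE is a `[batch, height, width, channels]` tensor, so duplicating one image into a frames list amounts to repeating the batch axis:

```python
# Conceptual sketch: ComfyUI IMAGE tensors are [batch, height, width, channels]
# floats in 0..1; "duplicating" one image is just repeating the batch axis.
import torch

image = torch.rand(1, 512, 512, 3)   # a single frame
frames = image.repeat(24, 1, 1, 1)   # 24 identical frames, e.g. for VAE Encode
print(frames.shape)                  # torch.Size([24, 512, 512, 3])
```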
148 |
+
|
149 |
+
- ReActorFaceSwapOpt (a simplified version of the Main Node) + ReActorOptions Node to set additional options, such as the (new) separate input/source faces order. Yes! You can now set the order of faces in the index any way you want ("large to small" by default)!
|
150 |
+
|
151 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.0-whatsnew-04.jpg?raw=true" alt="0.5.0-whatsnew-04" width="100%"/>
|
152 |
+
|
153 |
+
- Little speed boost when analyzing target images (unfortunately, it is still quite slow compared to swapping and restoring...)
|
154 |
+
|
155 |
+
### [0.4.2](https://github.com/Gourieff/comfyui-reactor-node/releases/tag/v0.4.2)
|
156 |
+
|
157 |
+
- GPEN-BFR-512 and RestoreFormer_Plus_Plus face restoration models support
|
158 |
+
|
159 |
+
You can download models here: https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models/facerestore_models
|
160 |
+
<br>Put them into the `ComfyUI\models\facerestore_models` folder
|
161 |
+
|
162 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.2-whatsnew-04.jpg?raw=true" alt="0.4.2-whatsnew-04" width="100%"/>
|
163 |
+
|
164 |
+
- Due to popular demand - you can now blend several images of people into one face model file and use it with the "Load Face Model" Node or in SD WebUI as well;
|
165 |
+
|
166 |
+
Experiment and create new faces or blend faces of one person to gain better accuracy and likeness!
|
167 |
+
|
168 |
+
Just add ImpactPack's "Make Image Batch" Node as the input to ReActor's one and load the images you want to blend into one model:
|
169 |
+
|
170 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.2-whatsnew-01.jpg?raw=true" alt="0.4.2-whatsnew-01" width="100%"/>
|
171 |
+
|
172 |
+
Result example (the new face was created from 4 faces of different actresses):
|
173 |
+
|
174 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.2-whatsnew-02.jpg?raw=true" alt="0.4.2-whatsnew-02" width="75%"/>
|
175 |
+
|
176 |
+
Basic workflow [💾](https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/workflows/ReActor--Build-Blended-Face-Model--v1.json)
|
177 |
+
|
178 |
+
### [0.4.1](https://github.com/Gourieff/comfyui-reactor-node/releases/tag/v0.4.1)
|
179 |
+
|
180 |
+
- CUDA 12 Support - don't forget to run (Windows) `install.bat` or (Linux/MacOS) `install.py` for ComfyUI's Python environment or try to install ORT-GPU for CU12 manually (https://onnxruntime.ai/docs/install/#install-onnx-runtime-gpu-cuda-12x)
|
181 |
+
- Issue https://github.com/Gourieff/comfyui-reactor-node/issues/173 fix
|
182 |
+
|
183 |
+
- Separate Node for the Face Restoration postprocessing (FR https://github.com/Gourieff/comfyui-reactor-node/issues/191), can be found inside ReActor's menu (RestoreFace Node)
|
184 |
+
- (Windows) Installation can be done for the Python found in the system PATH
|
185 |
+
- Different fixes and improvements
|
186 |
+
|
187 |
+
- Face Restore Visibility and CodeFormer Weight (Fidelity) options are now available! Don't forget to reload the Node in your existing workflow
|
188 |
+
|
189 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.1-whatsnew-01.jpg?raw=true" alt="0.4.1-whatsnew-01" width="100%"/>
|
190 |
+
|
191 |
+
### [0.4.0](https://github.com/Gourieff/comfyui-reactor-node/releases/tag/v0.4.0)
|
192 |
+
|
193 |
+
- Input "input_image" goes first now, it gives a correct bypass and also it is right to have the main input first;
|
194 |
+
- You can now save face models as "safetensors" files (`ComfyUI\models\reactor\faces`) and load them into ReActor, enabling different scenarios and keeping super-lightweight face models of the faces you use:
|
195 |
+
|
196 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.0-whatsnew-01.jpg?raw=true" alt="0.4.0-whatsnew-01" width="100%"/>
|
197 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.0-whatsnew-02.jpg?raw=true" alt="0.4.0-whatsnew-02" width="100%"/>
|
198 |
+
|
199 |
+
- Ability to build and save face models directly from an image:
|
200 |
+
|
201 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.0-whatsnew-03.jpg?raw=true" alt="0.4.0-whatsnew-03" width="50%"/>
|
202 |
+
|
203 |
+
- Both inputs are optional - just connect one of them according to your workflow; if both are connected, `image` has priority.
|
204 |
+
- Different fixes making this extension better.
|
205 |
+
|
206 |
+
Thanks to everyone who finds bugs, suggests new features and supports this project!
|
207 |
+
|
208 |
+
</details>
|
209 |
+
|
210 |
+
## Installation
|
211 |
+
|
212 |
+
<details>
|
213 |
+
<summary>SD WebUI: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/">AUTOMATIC1111</a> or <a href="https://github.com/vladmandic/automatic">SD.Next</a></summary>
|
214 |
+
|
215 |
+
1. Close (stop) your SD-WebUI/Comfy Server if it's running
|
216 |
+
2. (For Windows Users):
|
217 |
+
- Install [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/) (Community version - you need this step to build Insightface)
|
218 |
+
- OR only [VS C++ Build Tools](https://visualstudio.microsoft.com/visual-cpp-build-tools/) and select "Desktop Development with C++" under "Workloads -> Desktop & Mobile"
|
219 |
+
- OR if you don't want to install VS or VS C++ BT - follow [these steps (sec. I)](#insightfacebuild)
|
220 |
+
3. Go to the `extensions\sd-webui-comfyui\ComfyUI\custom_nodes`
|
221 |
+
4. Open Console or Terminal and run `git clone https://github.com/Gourieff/ComfyUI-ReActor`
|
222 |
+
5. Go to the SD WebUI root folder, open Console or Terminal and run (Windows users) `.\venv\Scripts\activate` or (Linux/MacOS) `venv/bin/activate`
|
223 |
+
6. `python -m pip install -U pip`
|
224 |
+
7. `cd extensions\sd-webui-comfyui\ComfyUI\custom_nodes\ComfyUI-ReActor`
|
225 |
+
8. `python install.py`
|
226 |
+
9. Please wait until the installation process is finished
|
227 |
+
10. (Since version 0.3.0) Download additional face restoration models from the link below and put them into the `extensions\sd-webui-comfyui\ComfyUI\models\facerestore_models` directory:<br>
|
228 |
+
https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models/facerestore_models
|
229 |
+
11. Run SD WebUI and check the console for the message that ReActor Node is running:
|
230 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/uploads/console_status_running.jpg?raw=true" alt="console_status_running" width="759"/>
|
231 |
+
|
232 |
+
12. Go to the ComfyUI tab and find the ReActor Node there, inside the `ReActor` menu or by using search:
|
233 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/uploads/webui-demo.png?raw=true" alt="webui-demo" width="100%"/>
|
234 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/uploads/search-demo.png?raw=true" alt="webui-demo" width="1043"/>
|
235 |
+
|
236 |
+
</details>
|
237 |
+
|
238 |
+
<details>
|
239 |
+
<summary>Standalone (Portable) <a href="https://github.com/comfyanonymous/ComfyUI">ComfyUI</a> for Windows</summary>
|
240 |
+
|
241 |
+
1. Do the following:
|
242 |
+
- Install [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/) (Community version - you need this step to build Insightface)
|
243 |
+
- OR only [VS C++ Build Tools](https://visualstudio.microsoft.com/visual-cpp-build-tools/) and select "Desktop Development with C++" under "Workloads -> Desktop & Mobile"
|
244 |
+
- OR if you don't want to install VS or VS C++ BT - follow [these steps (sec. I)](#insightfacebuild)
|
245 |
+
2. Choose between two options:
|
246 |
+
- (ComfyUI Manager) Open ComfyUI Manager, click "Install Custom Nodes", type "ReActor" in the "Search" field and then click "Install". After ComfyUI completes the process - please restart the Server.
|
247 |
+
- (Manually) Go to `ComfyUI\custom_nodes`, open Console and run `git clone https://github.com/Gourieff/ComfyUI-ReActor`
|
248 |
+
3. Go to `ComfyUI\custom_nodes\ComfyUI-ReActor` and run `install.bat`
|
249 |
+
4. If you don't have the "face_yolov8m.pt" Ultralytics model - you can download it from the [Assets](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/detection/bbox/face_yolov8m.pt) and put it into the "ComfyUI\models\ultralytics\bbox" directory<br>As well as one or both of the "Sams" models from [here](https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models/sams) - download (if you don't have them) and put them into the "ComfyUI\models\sams" directory
|
250 |
+
5. Run ComfyUI and find the ReActor Nodes inside the `ReActor` menu or by using search
|
251 |
+
|
252 |
+
</details>
|
253 |
+
|
254 |
+
## Usage
|
255 |
+
|
256 |
+
You can find ReActor Nodes inside the menu `ReActor` or by using a search (just type "ReActor" in the search field)
|
257 |
+
|
258 |
+
List of Nodes:
|
259 |
+
- ••• Main Nodes •••
|
260 |
+
- ReActorFaceSwap (Main Node)
|
261 |
+
- ReActorFaceSwapOpt (Main Node with the additional Options input)
|
262 |
+
- ReActorOptions (Options for ReActorFaceSwapOpt)
|
263 |
+
- ReActorFaceBoost (Face Booster Node)
|
264 |
+
- ReActorMaskHelper (Masking Helper)
|
265 |
+
- ••• Operations with Face Models •••
|
266 |
+
- ReActorSaveFaceModel (Save Face Model)
|
267 |
+
- ReActorLoadFaceModel (Load Face Model)
|
268 |
+
- ReActorBuildFaceModel (Build Blended Face Model)
|
269 |
+
- ReActorMakeFaceModelBatch (Make Face Model Batch)
|
270 |
+
- ••• Additional Nodes •••
|
271 |
+
- ReActorRestoreFace (Face Restoration)
|
272 |
+
- ReActorImageDublicator (Dublicate one Image to Images List)
|
273 |
+
- ImageRGBA2RGB (Convert RGBA to RGB)
|
274 |
+
|
275 |
+
Connect all required slots and run the queue.
|
276 |
+
|
277 |
+
### Main Node Inputs
|
278 |
+
|
279 |
+
- `input_image` - is an image to be processed (target image, analog of "target image" in the SD WebUI extension);
|
280 |
+
- Supported Nodes: "Load Image", "Load Video" or any other nodes providing images as an output;
|
281 |
+
- `source_image` - is an image with a face or faces to swap in the `input_image` (source image, analog of "source image" in the SD WebUI extension);
|
282 |
+
- Supported Nodes: "Load Image" or any other nodes providing images as an output;
|
283 |
+
- `face_model` - is the input for the "Load Face Model" Node or another ReActor node to provide a face model file (face embedding) you created earlier via the "Save Face Model" Node;
|
284 |
+
- Supported Nodes: "Load Face Model", "Build Blended Face Model";
|
285 |
+
|
286 |
+
### Main Node Outputs
|
287 |
+
|
288 |
+
- `IMAGE` - is an output with the resulting image;
|
289 |
+
- Supported Nodes: any nodes that accept images as an input;
|
290 |
+
- `FACE_MODEL` - is an output providing the source face model built during the swapping process;
|
291 |
+
- Supported Nodes: "Save Face Model", "ReActor", "Make Face Model Batch";
|
292 |
+
|
293 |
+
### Face Restoration
|
294 |
+
|
295 |
+
Since version 0.3.0 ReActor Node has built-in face restoration.<br>Just download the models you want (see the [Installation](#installation) instructions) and select one of them to restore the resulting face(s) during the faceswap. It will enhance face details and make your result more accurate.
|
296 |
+
|
297 |
+
### Face Indexes
|
298 |
+
|
299 |
+
By default ReActor detects faces in images from "large" to "small".<br>You can change this option by adding the ReActorFaceSwapOpt node together with ReActorOptions.
|
300 |
+
|
301 |
+
And if you need to specify faces, you can set indexes for source and input images.
|
302 |
+
|
303 |
+
Index of the first detected face is 0.
|
304 |
+
|
305 |
+
You can set indexes in the order you need.<br>
|
306 |
+
E.g.: 0,1,2 (for Source); 1,0,2 (for Input).<br>This means: the second Input face (index = 1) will be swapped with the first Source face (index = 0) and so on.
|
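In other words, indexes are paired by position - the i-th Source index is applied to the i-th Input index. A tiny illustration of that reading (not ReActor's internal code):

```python
# Illustration of the index semantics above: indexes are paired by position.
def pair_faces(source_idx: str, input_idx: str) -> list[tuple[int, int]]:
    src = [int(i) for i in source_idx.split(",")]
    dst = [int(i) for i in input_idx.split(",")]
    return list(zip(src, dst))  # (source_face, input_face) pairs

# Source 0 -> Input 1, Source 1 -> Input 0, Source 2 -> Input 2
print(pair_faces("0,1,2", "1,0,2"))  # [(0, 1), (1, 0), (2, 2)]
```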
307 |
+
|
308 |
+
### Genders
|
309 |
+
|
310 |
+
You can specify the gender to detect in images.<br>
|
311 |
+
ReActor will swap a face only if it meets the given condition.
|
312 |
+
|
313 |
+
### Face Models
|
314 |
+
|
315 |
+
Since version 0.4.0 you can save face models as "safetensors" files (stored in `ComfyUI\models\reactor\faces`) and load them into ReActor, enabling different scenarios and keeping super-lightweight face models of the faces you use.
|
316 |
+
|
317 |
+
To make new models appear in the list of the "Load Face Model" Node - just refresh the page of your ComfyUI web application.<br>
|
318 |
+
(I recommend you to use ComfyUI Manager - otherwise your workflow can be lost after you refresh the page if you didn't save it before that).
|
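If you want to inspect a saved face model from Python, here is a minimal sketch assuming the `safetensors` package is installed (the tensor key names depend on the ReActor version that wrote the file):

```python
# Sketch: list the tensors stored in saved face models. Key names and shapes
# depend on the ReActor version that created the files.
from pathlib import Path
from safetensors.numpy import load_file

faces_dir = Path("ComfyUI/models/reactor/faces")
for model_path in sorted(faces_dir.glob("*.safetensors")):
    tensors = load_file(str(model_path))
    print(model_path.name, {k: v.shape for k, v in tensors.items()})
```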
319 |
+
|
320 |
+
## Troubleshooting
|
321 |
+
|
322 |
+
<a name="insightfacebuild">
|
323 |
+
|
324 |
+
### **I. (For Windows users) If you still cannot build Insightface for some reason or just don't want to install Visual Studio or VS C++ Build Tools - do the following:**
|
325 |
+
|
326 |
+
1. (ComfyUI Portable) From the root folder check the version of Python:<br>run CMD and type `python_embeded\python.exe -V`
|
327 |
+
2. Download the prebuilt Insightface package [for Python 3.10](https://github.com/Gourieff/Assets/raw/main/Insightface/insightface-0.7.3-cp310-cp310-win_amd64.whl) or [for Python 3.11](https://github.com/Gourieff/Assets/raw/main/Insightface/insightface-0.7.3-cp311-cp311-win_amd64.whl) (if in the previous step you saw 3.11) or [for Python 3.12](https://github.com/Gourieff/Assets/raw/main/Insightface/insightface-0.7.3-cp312-cp312-win_amd64.whl) (if in the previous step you saw 3.12) and put it into the stable-diffusion-webui (A1111 or SD.Next) root folder (where you have the "webui-user.bat" file) or into the ComfyUI root folder if you use ComfyUI Portable (a helper sketch for picking the right wheel follows this list)
|
328 |
+
3. From the root folder run:
|
329 |
+
- (SD WebUI) CMD and `.\venv\Scripts\activate`
|
330 |
+
- (ComfyUI Portable) run CMD
|
331 |
+
4. Then update your PIP:
|
332 |
+
- (SD WebUI) `python -m pip install -U pip`
|
333 |
+
- (ComfyUI Portable) `python_embeded\python.exe -m pip install -U pip`
|
334 |
+
5. Then install Insightface:
|
335 |
+
- (SD WebUI) `pip install insightface-0.7.3-cp310-cp310-win_amd64.whl` (for 3.10) or `pip install insightface-0.7.3-cp311-cp311-win_amd64.whl` (for 3.11) or `pip install insightface-0.7.3-cp312-cp312-win_amd64.whl` (for 3.12)
|
336 |
+
- (ComfyUI Portable) `python_embeded\python.exe -m pip install insightface-0.7.3-cp310-cp310-win_amd64.whl` (for 3.10) or `python_embeded\python.exe -m pip install insightface-0.7.3-cp311-cp311-win_amd64.whl` (for 3.11) or `python_embeded\python.exe -m pip install insightface-0.7.3-cp312-cp312-win_amd64.whl` (for 3.12)
|
337 |
+
6. Enjoy!
|
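To avoid grabbing the wrong wheel, you can let the embedded interpreter tell you which file it needs - a small helper sketch of mine (not an official script); save it as e.g. `pick_wheel.py` (any name works) and run it with `python_embeded\python.exe pick_wheel.py`:

```python
# Sketch: print which prebuilt Insightface wheel matches this interpreter.
import sys

wheels = {
    (3, 10): "insightface-0.7.3-cp310-cp310-win_amd64.whl",
    (3, 11): "insightface-0.7.3-cp311-cp311-win_amd64.whl",
    (3, 12): "insightface-0.7.3-cp312-cp312-win_amd64.whl",
}
print(wheels.get(sys.version_info[:2], "no prebuilt wheel for this Python"))
```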
338 |
+
|
339 |
+
### **II. "AttributeError: 'NoneType' object has no attribute 'get'"**
|
340 |
+
|
341 |
+
This error may occur if there's something wrong with the model file `inswapper_128.onnx`
|
342 |
+
|
343 |
+
Try to download it manually from [here](https://github.com/facefusion/facefusion-assets/releases/download/models/inswapper_128.onnx)
|
344 |
+
and put it into the `ComfyUI\models\insightface` folder, replacing the existing one
|
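Before re-downloading, you can also check whether the file you already have is structurally valid - a quick sketch, assuming the `onnx` Python package is installed:

```python
# Sketch: a corrupted inswapper_128.onnx typically fails to load or validate.
import onnx

model = onnx.load("ComfyUI/models/insightface/inswapper_128.onnx")
onnx.checker.check_model(model)
print("model structure looks valid")
```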
345 |
+
|
346 |
+
### **III. "reactor.execute() got an unexpected keyword argument 'reference_image'"**
|
347 |
+
|
348 |
+
This means that input points have been changed with the latest update<br>
|
349 |
+
Remove the current ReActor Node from your workflow and add it again
|
350 |
+
|
351 |
+
### **IV. ControlNet Aux Node IMPORT failed error when using with ReActor Node**
|
352 |
+
|
353 |
+
1. Close ComfyUI if it is running
|
354 |
+
2. Go to the ComfyUI root folder, open CMD there and run:
|
355 |
+
- `python_embeded\python.exe -m pip uninstall -y opencv-python opencv-contrib-python opencv-python-headless`
|
356 |
+
- `python_embeded\python.exe -m pip install opencv-python==4.7.0.72`
|
357 |
+
3. That's it!
|
358 |
+
|
359 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/uploads/reactor-w-controlnet.png?raw=true" alt="reactor+controlnet" />
|
360 |
+
|
361 |
+
### **V. "ModuleNotFoundError: No module named 'basicsr'" or "subprocess-exited-with-error" during future-0.18.3 installation**
|
362 |
+
|
363 |
+
- Download https://github.com/Gourieff/Assets/raw/main/comfyui-reactor-node/future-0.18.3-py3-none-any.whl<br>
|
364 |
+
- Put it into the ComfyUI root folder and run:
|
365 |
+
|
366 |
+
python_embeded\python.exe -m pip install future-0.18.3-py3-none-any.whl
|
367 |
+
|
368 |
+
- Then:
|
369 |
+
|
370 |
+
python_embeded\python.exe -m pip install basicsr
|
371 |
+
|
372 |
+
### **VI. "fatal: fetch-pack: invalid index-pack output" when you try to `git clone` the repository"**
|
373 |
+
|
374 |
+
Try to clone with `--depth=1` (last commit only):
|
375 |
+
|
376 |
+
git clone --depth=1 https://github.com/Gourieff/ComfyUI-ReActor
|
377 |
+
|
378 |
+
Then retrieve the rest (if you need it):
|
379 |
+
|
380 |
+
git fetch --unshallow
|
381 |
+
|
382 |
+
## Updating
|
383 |
+
|
384 |
+
Just put the .bat or .sh script from this [Repo](https://github.com/Gourieff/sd-webui-extensions-updater) into the `ComfyUI\custom_nodes` directory and run it when you need to check for updates
|
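If you'd rather not use the linked scripts, roughly the same effect can be had with a few lines of Python - a sketch that assumes every custom node is a plain git checkout:

```python
# Sketch: fast-forward every git-managed custom node, ReActor included.
import subprocess
from pathlib import Path

for repo in Path("ComfyUI/custom_nodes").iterdir():
    if (repo / ".git").exists():
        print(f"updating {repo.name} ...")
        subprocess.run(["git", "-C", str(repo), "pull", "--ff-only"], check=False)
```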
385 |
+
|
386 |
+
### Disclaimer
|
387 |
+
|
388 |
+
This software is meant to be a productive contribution to the rapidly growing AI-generated media industry. It will help artists with tasks such as animating a custom character or using the character as a model for clothing, etc.
|
389 |
+
|
390 |
+
The developers of this software are aware of its possible unethical applications and are committed to taking preventative measures against them. We will continue to develop this project in a positive direction while adhering to law and ethics.
|
391 |
+
|
392 |
+
Users of this software are expected to use it responsibly while abiding by local law. If the face of a real person is used, users are advised to get consent from the person concerned and to clearly mention that it is a deepfake when posting the content online. **Developers and Contributors of this software are not responsible for the actions of end-users.**
|
393 |
+
|
394 |
+
By using this extension you agree not to create any content that:
|
395 |
+
- violates any laws;
|
396 |
+
- causes any harm to a person or persons;
|
397 |
+
- propagates (spreads) any information (whether public or personal) or images (whether public or personal) which could be meant for harm;
|
398 |
+
- spreads misinformation;
|
399 |
+
- targets vulnerable groups of people.
|
400 |
+
|
401 |
+
This software utilizes the pre-trained models `buffalo_l` and `inswapper_128.onnx`, which are provided by [InsightFace](https://github.com/deepinsight/insightface/). These models are included under the following conditions:
|
402 |
+
|
403 |
+
[From the insightface license](https://github.com/deepinsight/insightface/tree/master/python-package): InsightFace's pre-trained models are available for non-commercial research purposes only. This includes both auto-downloading models and manually downloaded models.
|
404 |
+
|
405 |
+
Users of this software must strictly adhere to these conditions of use. The developers and maintainers of this software are not responsible for any misuse of InsightFace’s pre-trained models.
|
406 |
+
|
407 |
+
Please note that if you intend to use this software for any commercial purposes, you will need to train your own models or find models that can be used commercially.
|
408 |
+
|
409 |
+
### Models Hashsum
|
410 |
+
|
411 |
+
#### Safe-to-use models have the following hashes:
|
412 |
+
|
413 |
+
inswapper_128.onnx
|
414 |
+
```
|
415 |
+
MD5:a3a155b90354160350efd66fed6b3d80
|
416 |
+
SHA256:e4a3f08c753cb72d04e10aa0f7dbe3deebbf39567d4ead6dce08e98aa49e16af
|
417 |
+
```
|
418 |
+
|
419 |
+
1k3d68.onnx
|
420 |
+
|
421 |
+
```
|
422 |
+
MD5:6fb94fcdb0055e3638bf9158e6a108f4
|
423 |
+
SHA256:df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc
|
424 |
+
```
|
425 |
+
|
426 |
+
2d106det.onnx
|
427 |
+
|
428 |
+
```
|
429 |
+
MD5:a3613ef9eb3662b4ef88eb90db1fcf26
|
430 |
+
SHA256:f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf
|
431 |
+
```
|
432 |
+
|
433 |
+
det_10g.onnx
|
434 |
+
|
435 |
+
```
|
436 |
+
MD5:4c10eef5c9e168357a16fdd580fa8371
|
437 |
+
SHA256:5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91
|
438 |
+
```
|
439 |
+
|
440 |
+
genderage.onnx
|
441 |
+
|
442 |
+
```
|
443 |
+
MD5:81c77ba87ab38163b0dec6b26f8e2af2
|
444 |
+
SHA256:4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb
|
445 |
+
```
|
446 |
+
|
447 |
+
w600k_r50.onnx
|
448 |
+
|
449 |
+
```
|
450 |
+
MD5:80248d427976241cbd1343889ed132b3
|
451 |
+
SHA256:4c06341c33c2ca1f86781dab0e829f88ad5b64be9fba56e56bc9ebdefc619e43
|
452 |
+
```
|
453 |
+
|
454 |
+
**Please check the hashsums if you download these models from unverified (or untrusted) sources**
|
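To compute both digests for a downloaded file and compare them with the values above, here is a small helper sketch (it is not shipped with ReActor):

```python
# Sketch: print MD5 and SHA256 of a model file for manual comparison.
import hashlib
import sys

def file_hashes(path: str, chunk: int = 1 << 20) -> tuple[str, str]:
    md5, sha256 = hashlib.md5(), hashlib.sha256()
    with open(path, "rb") as f:
        while data := f.read(chunk):
            md5.update(data)
            sha256.update(data)
    return md5.hexdigest(), sha256.hexdigest()

if __name__ == "__main__":
    m, s = file_hashes(sys.argv[1])
    print("MD5:" + m)
    print("SHA256:" + s)
```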
455 |
+
|
456 |
+
<a name="credits">
|
457 |
+
|
458 |
+
## Thanks and Credits
|
459 |
+
|
460 |
+
<details>
|
461 |
+
<summary><a>Click to expand</a></summary>
|
462 |
+
|
463 |
+
<br>
|
464 |
+
|
465 |
+
|file|source|license|
|
466 |
+
|----|------|-------|
|
467 |
+
|[buffalo_l.zip](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/buffalo_l.zip) | [DeepInsight](https://github.com/deepinsight/insightface) |  |
|
468 |
+
| [codeformer-v0.1.0.pth](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/facerestore_models/codeformer-v0.1.0.pth) | [sczhou](https://github.com/sczhou/CodeFormer) |  |
|
469 |
+
| [GFPGANv1.3.pth](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/facerestore_models/GFPGANv1.3.pth) | [TencentARC](https://github.com/TencentARC/GFPGAN) |  |
|
470 |
+
| [GFPGANv1.4.pth](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/facerestore_models/GFPGANv1.4.pth) | [TencentARC](https://github.com/TencentARC/GFPGAN) |  |
|
471 |
+
| [inswapper_128.onnx](https://github.com/facefusion/facefusion-assets/releases/download/models/inswapper_128.onnx) | [DeepInsight](https://github.com/deepinsight/insightface) |  |
|
472 |
+
| [inswapper_128_fp16.onnx](https://github.com/facefusion/facefusion-assets/releases/download/models/inswapper_128_fp16.onnx) | [Hillobar](https://github.com/Hillobar/Rope) |  |
|
473 |
+
|
474 |
+
[BasicSR](https://github.com/XPixelGroup/BasicSR) - [@XPixelGroup](https://github.com/XPixelGroup) <br>
|
475 |
+
[facexlib](https://github.com/xinntao/facexlib) - [@xinntao](https://github.com/xinntao) <br>
|
476 |
+
|
477 |
+
[@s0md3v](https://github.com/s0md3v), [@henryruhs](https://github.com/henryruhs) - the original Roop App <br>
|
478 |
+
[@ssitu](https://github.com/ssitu) - the first version of [ComfyUI_roop](https://github.com/ssitu/ComfyUI_roop) extension
|
479 |
+
|
480 |
+
</details>
|
481 |
+
|
482 |
+
<a name="note">
|
483 |
+
|
484 |
+
### Note!
|
485 |
+
|
486 |
+
**If you encounter any errors when you use ReActor Node - don't rush to open an issue; first try to remove the current ReActor node from your workflow and add it again**
|
487 |
+
|
488 |
+
**ReActor Node gets updates from time to time; new functions appear, and an old node can work with errors or not work at all**
|
custom_nodes/ComfyUI-ReActor/README_RU.md
ADDED
@@ -0,0 +1,497 @@
1 |
+
<div align="center">
|
2 |
+
|
3 |
+
<img src="https://github.com/Gourieff/Assets/raw/main/sd-webui-reactor/ReActor_logo_NEW_RU.png?raw=true" alt="logo" width="180px"/>
|
4 |
+
|
5 |
+

|
6 |
+
|
7 |
+
<!--<sup>
|
8 |
+
<font color=brightred>
|
9 |
+
|
10 |
+
## !!! [Важные изменения](#latestupdate) !!!<br>Не забудьте добавить Нод заново в существующие воркфлоу
|
11 |
+
|
12 |
+
</font>
|
13 |
+
</sup>-->
|
14 |
+
|
15 |
+
<a href="https://boosty.to/artgourieff" target="_blank">
|
16 |
+
<img src="https://lovemet.ru/img/boosty.jpg" width="108" alt="Поддержать проект на Boosty"/>
|
17 |
+
<br>
|
18 |
+
<sup>
|
19 |
+
Поддержать проект
|
20 |
+
</sup>
|
21 |
+
</a>
|
22 |
+
|
23 |
+
<hr>
|
24 |
+
|
25 |
+
[](https://github.com/Gourieff/ComfyUI-ReActor/commits/main)
|
26 |
+

|
27 |
+
[](https://github.com/Gourieff/ComfyUI-ReActor/issues?cacheSeconds=0)
|
28 |
+
[](https://github.com/Gourieff/ComfyUI-ReActor/issues?q=is%3Aissue+state%3Aclosed)
|
29 |
+

|
30 |
+
|
31 |
+
[English](/README.md) | Русский
|
32 |
+
|
33 |
+
# ReActor Nodes для ComfyUI<br><sub><sup>-=Безопасно для работы | SFW-Friendly=-</sup></sub>
|
34 |
+
|
35 |
+
</div>
|
36 |
+
|
37 |
+
### Ноды (nodes) для быстрой и простой замены лиц на любых изображениях для работы с ComfyUI, основан на [ранее заблокированном РеАкторе](https://github.com/Gourieff/comfyui-reactor-node) - теперь имеется встроенный NSFW-детектор, исключающий замену лиц на изображениях с контентом 18+
|
38 |
+
|
39 |
+
> Используя данное ПО, вы понимаете и принимаете [ответственность](#disclaimer)
|
40 |
+
|
41 |
+
<div align="center">
|
42 |
+
|
43 |
+
---
|
44 |
+
[**Что нового**](#latestupdate) | [**Установка**](#installation) | [**Использование**](#usage) | [**Устранение проблем**](#troubleshooting) | [**Обновление**](#updating) | [**Ответственность**](#disclaimer) | [**Благодарности**](#credits) | [**Заметка**](#note)
|
45 |
+
|
46 |
+
---
|
47 |
+
|
48 |
+
</div>
|
49 |
+
|
50 |
+
<a name="latestupdate">
|
51 |
+
|
52 |
+
## Что нового в последнем обновлении
|
53 |
+
|
54 |
+
### 0.6.0 <sub><sup>ALPHA1</sup></sub>
|
55 |
+
|
56 |
+
- Новый нод `ReActorSetWeight` - теперь можно установить силу замены лица для `source_image` или `face_model` от 0% до 100% (с шагом 12.5%)
|
57 |
+
|
58 |
+
<center>
|
59 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.6.0-whatsnew-01.jpg?raw=true" alt="0.6.0-whatsnew-01" width="100%"/>
|
60 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.6.0-whatsnew-02.jpg?raw=true" alt="0.6.0-whatsnew-02" width="100%"/>
|
61 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.6.0-alpha1-01.gif?raw=true" alt="0.6.0-whatsnew-03" width="540px"/>
|
62 |
+
</center>
|
63 |
+
|
64 |
+
<details>
|
65 |
+
<summary><a>Предыдущие версии</a></summary>
|
66 |
+
|
67 |
+
### 0.5.2
|
68 |
+
|
69 |
+
- Поддержка моделей ReSwapper. Несмотря на то, что Inswapper по-прежнему даёт лучшее сходство, ReSwapper развивается - спасибо @somanchiu https://github.com/somanchiu/ReSwapper за эти модели и проект ReSwapper! Это хороший шаг для Сообщества в создании альтернативы Инсваппера!
|
70 |
+
|
71 |
+
<center>
|
72 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.2-whatsnew-03.jpg?raw=true" alt="0.5.2-whatsnew-03" width="75%"/>
|
73 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.2-whatsnew-04.jpg?raw=true" alt="0.5.2-whatsnew-04" width="75%"/>
|
74 |
+
</center>
|
75 |
+
|
76 |
+
Скачать модели ReSwapper можно отсюда:
|
77 |
+
https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models
|
78 |
+
Сохраните их в директорию "models/reswapper".
|
79 |
+
|
80 |
+
- NSFW-детектор, чтобы не нарушать [правила GitHub](https://docs.github.com/en/site-policy/acceptable-use-policies/github-misinformation-and-disinformation#synthetic--manipulated-media-tools)
|
81 |
+
- Новый нод "Unload ReActor Models" - полезен для сложных воркфлоу, когда вам нужно освободить ОЗУ, занятую РеАктором
|
82 |
+
|
83 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.2-whatsnew-01.jpg?raw=true" alt="0.5.2-whatsnew-01" width="100%"/>
|
84 |
+
|
85 |
+
- Поддержка ORT CoreML и ROCM EPs, достаточно установить ту версию onnxruntime, которая соответствует вашему GPU
|
86 |
+
- Некоторые улучшения скрипта установки для поддержки последней версии ORT-GPU
|
87 |
+
|
88 |
+
<center>
|
89 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.2-whatsnew-02.jpg?raw=true" alt="0.5.2-whatsnew-02" width="50%"/>
|
90 |
+
</center>
|
91 |
+
|
92 |
+
- Исправления и улучшения
|
93 |
+
|
94 |
+
### 0.5.1
|
95 |
+
|
96 |
+
- Поддержка моделей восстановления лиц GPEN 1024/2048 (доступны в датасете на HF https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models/facerestore_models)
|
97 |
+
- Нод ReActorFaceBoost - попытка улучшить качество заменённых лиц. Идея состоит в том, чтобы восстановить и увеличить заменённое лицо (в соответствии с параметром `face_size` модели реставрации) ДО того, как лицо будет вставлено в целевое изображения (через алгоритмы инсваппера), больше информации [здесь (PR#321)](https://github.com/Gourieff/comfyui-reactor-node/pull/321)
|
98 |
+
|
99 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.1-whatsnew-01.jpg?raw=true" alt="0.5.1-whatsnew-01" width="100%"/>
|
100 |
+
|
101 |
+
[Полноразмерное демо-превью](https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.1-whatsnew-02.png)
|
102 |
+
|
103 |
+
- Сортировка моделей лиц по алфавиту
|
104 |
+
- Множество исправлений и улучшений
|
105 |
+
|
106 |
+
### [0.5.0 <sub><sup>BETA4</sup></sub>](https://github.com/Gourieff/comfyui-reactor-node/releases/tag/v0.5.0)
|
107 |
+
|
108 |
+
- Поддержка библиотеки Spandrel при работе с GFPGAN
|
109 |
+
|
110 |
+
### 0.5.0 <sub><sup>BETA3</sup></sub>
|
111 |
+
|
112 |
+
- Исправления: "RAM issue", "No detection" для MaskingHelper
|
113 |
+
|
114 |
+
### 0.5.0 <sub><sup>BETA2</sup></sub>
|
115 |
+
|
116 |
+
- Появилась возможность строить смешанные модели лиц из пачки уже имеющихся моделей - добавьте для этого нод "Make Face Model Batch" в свой воркфлоу и загрузите несколько моделей через ноды "Load Face Model"
|
117 |
+
- Огромный буст производительности модуля анализа изображений! 10-кратный прирост скорости! Работа с видео теперь в удовольствие!
|
118 |
+
|
119 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.0-whatsnew-05.png?raw=true" alt="0.5.0-whatsnew-05" width="100%"/>
|
120 |
+
|
121 |
+
### 0.5.0 <sub><sup>BETA1</sup></sub>
|
122 |
+
|
123 |
+
- Добавлен выход SWAPPED_FACE для нода Masking Helper
|
124 |
+
- FIX: Удалён пустой A-канал на выходе IMAGE нода Masking Helper (вызывавший ошибки с некоторым нодами)
|
125 |
+
|
126 |
+
### 0.5.0 <sub><sup>ALPHA1</sup></sub>
|
127 |
+
|
128 |
+
- Нод ReActorBuildFaceModel получил выход "face_model" для отправки совмещенной модели лиц непосредственно в основной Нод:
|
129 |
+
|
130 |
+
Basic workflow [💾](https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/workflows/ReActor--Build-Blended-Face-Model--v2.json)
|
131 |
+
|
132 |
+
- Функции маски лица теперь доступна и в версии для Комфи, просто добавьте нод "ReActorMaskHelper" в воркфлоу и соедините узлы, как показано ниже:
|
133 |
+
|
134 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.0-whatsnew-01.jpg?raw=true" alt="0.5.0-whatsnew-01" width="100%"/>
|
135 |
+
|
136 |
+
Если модель "face_yolov8m.pt" у вас отсутствует - можете скачать её [отсюда](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/detection/bbox/face_yolov8m.pt) и положить в папку "ComfyUI\models\ultralytics\bbox"
|
137 |
+
<br>
|
138 |
+
То же самое и с ["sam_vit_b_01ec64.pth"](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/sams/sam_vit_b_01ec64.pth) - скачайте (если отсутствует) и положите в папку "ComfyUI\models\sams";
|
139 |
+
|
140 |
+
Данный нод поможет вам получить куда более аккуратный результат при замене лиц:
|
141 |
+
|
142 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.0-whatsnew-02.jpg?raw=true" alt="0.5.0-whatsnew-02" width="100%"/>
|
143 |
+
|
144 |
+
- Нод ReActorImageDublicator - полезен тем, кто создает видео, помогает продублировать одиночное изображение в несколько копий, чтобы использовать их, к примеру, с VAE энкодером:
|
145 |
+
|
146 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.0-whatsnew-03.jpg?raw=true" alt="0.5.0-whatsnew-03" width="100%"/>
|
147 |
+
|
148 |
+
- ReActorFaceSwapOpt (упрощенная версия основного нода) + нод ReActorOptions для установки дополнительных опций, как (новые) "отдельный порядок лиц для input/source". Да! Теперь можно установить любой порядок "чтения" индекса лиц на изображении, в т.ч. от большего к меньшему (по умолчанию)!
|
149 |
+
|
150 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.0-whatsnew-04.jpg?raw=true" alt="0.5.0-whatsnew-04" width="100%"/>
|
151 |
+
|
152 |
+
- Небольшое улучшение скорости анализа целевых изображений (input)
|
153 |
+
|
154 |
+
### [0.4.2](https://github.com/Gourieff/comfyui-reactor-node/releases/tag/v0.4.2)
|
155 |
+
|
156 |
+
- Добавлена поддержка GPEN-BFR-512 и RestoreFormer_Plus_Plus моделей восстановления лиц
|
157 |
+
|
158 |
+
Скачать можно здесь: https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models/facerestore_models
|
159 |
+
<br>Добавьте модели в папку `ComfyUI\models\facerestore_models`
|
160 |
+
|
161 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.2-whatsnew-04.jpg?raw=true" alt="0.4.2-whatsnew-04" width="100%"/>
|
162 |
+
|
163 |
+
- По многочисленным просьбам появилась возможность строить смешанные модели лиц и в ComfyUI тоже и использовать их с нодом "Load Face Model" Node или в SD WebUI;
|
164 |
+
|
165 |
+
Экспериментируйте и создавайте новые лица или совмещайте разные лица нужного вам персонажа, чтобы добиться лучшей точности и схожести с оригиналом!
|
166 |
+
|
167 |
+
Достаточно добавить нод "Make Image Batch" (ImpactPack) на вход нового нода РеАктора и загрузить в пачку необходимые вам изображения для построения смешанной модели:
|
168 |
+
|
169 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.2-whatsnew-01.jpg?raw=true" alt="0.4.2-whatsnew-01" width="100%"/>
|
170 |
+
|
171 |
+
Пример результата (на основе лиц 4-х актрис создано новое лицо):
|
172 |
+
|
173 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.2-whatsnew-02.jpg?raw=true" alt="0.4.2-whatsnew-02" width="75%"/>
|
174 |
+
|
175 |
+
Базовый воркфлоу [💾](https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/workflows/ReActor--Build-Blended-Face-Model--v1.json)
|
176 |
+
|
177 |
+
### [0.4.1](https://github.com/Gourieff/comfyui-reactor-node/releases/tag/v0.4.1)
|
178 |
+
|
179 |
+
- Поддержка CUDA 12 - не забудьте запустить (Windows) `install.bat` или (Linux/MacOS) `install.py` для используемого Python окружения или попробуйте установить ORT-GPU для CU12 вручную (https://onnxruntime.ai/docs/install/#install-onnx-runtime-gpu-cuda-12x)
|
180 |
+
- Исправление Issue https://github.com/Gourieff/comfyui-reactor-node/issues/173
|
181 |
+
|
182 |
+
- Отдельный Нод для восстановления лиц (FR https://github.com/Gourieff/comfyui-reactor-node/issues/191), располагается внутри меню ReActor (нод RestoreFace)
|
183 |
+
- (Windows) Установка зависимостей теперь может быть выполнена в Python из PATH ОС
|
184 |
+
- Разные исправления и улучшения
|
185 |
+
|
186 |
+
- Face Restore Visibility и CodeFormer Weight (Fidelity) теперь доступны; не забудьте заново добавить Нод в ваших существующих воркфлоу
|
187 |
+
|
188 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.1-whatsnew-01.jpg?raw=true" alt="0.4.1-whatsnew-01" width="100%"/>
|
189 |
+
|
190 |
+
### [0.4.0](https://github.com/Gourieff/comfyui-reactor-node/releases/tag/v0.4.0)
|
191 |
+
|
192 |
+
- Вход "input_image" теперь идёт первым, это даёт возможность корректного байпаса, а также это правильно с точки зрения расположения входов (главный вход - первый);
|
193 |
+
- Теперь можно сохранять модели лиц в качестве файлов "safetensors" (`ComfyUI\models\reactor\faces`) и загружать их в ReActor, реализуя разные сценарии использования, а также храня супер легкие модели лиц, которые вы чаще всего используете:
|
194 |
+
|
195 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.0-whatsnew-01.jpg?raw=true" alt="0.4.0-whatsnew-01" width="100%"/>
|
196 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.0-whatsnew-02.jpg?raw=true" alt="0.4.0-whatsnew-02" width="100%"/>
|
197 |
+
|
198 |
+
- Возможность сохранять модели лиц напрямую из изображения:
|
199 |
+
|
200 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.0-whatsnew-03.jpg?raw=true" alt="0.4.0-whatsnew-03" width="50%"/>
|
201 |
+
|
202 |
+
- Оба входа опциональны, присоедините один из них в соответствии с вашим воркфлоу; если присоединены оба - вход `image` имеет приоритет.
|
203 |
+
- Различные исправления, делающие это расширение лучше.
|
204 |
+
|
205 |
+
Спасибо всем, кто находит ошибки, предлагает новые функции и поддерживает данный проект!
|
206 |
+
|
207 |
+
</details>
|
208 |
+
|
209 |
+
<a name="installation">
|
210 |
+
|
211 |
+
## Установка
|
212 |
+
|
213 |
+
<details>
|
214 |
+
<summary>SD WebUI: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/">AUTOMATIC1111</a> или <a href="https://github.com/vladmandic/automatic">SD.Next</a></summary>
|
215 |
+
|
216 |
+
1. Закройте (остановите) SD-WebUI Сервер, если запущен
|
217 |
+
2. (Для пользователей Windows):
|
218 |
+
- Установите [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/) (Например, версию Community - этот шаг нужен для правильной компиляции библиотеки Insightface)
|
219 |
+
- ИЛИ только [VS C++ Build Tools](https://visualstudio.microsoft.com/visual-cpp-build-tools/), выберите "Desktop Development with C++" в разделе "Workloads -> Desktop & Mobile"
|
220 |
+
- ИЛИ если же вы не хотите устанавливать что-либо из вышеуказанного - выполните [данные шаги (раздел. I)](#insightfacebuild)
|
221 |
+
3. Перейдите в `extensions\sd-webui-comfyui\ComfyUI\custom_nodes`
|
222 |
+
4. Откройте Консоль или Терминал и выполните `git clone https://github.com/Gourieff/ComfyUI-ReActor`
|
223 |
+
5. Перейдите в корневую директорию SD WebUI, откройте Консоль или Терминал и выполните (для пользователей Windows)`.\venv\Scripts\activate` или (для пользователей Linux/MacOS)`venv/bin/activate`
|
224 |
+
6. `python -m pip install -U pip`
|
225 |
+
7. `cd extensions\sd-webui-comfyui\ComfyUI\custom_nodes\ComfyUI-ReActor`
|
226 |
+
8. `python install.py`
|
227 |
+
9. Пожалуйста, дождитесь полного завершения установки
|
228 |
+
10. (Начиная с версии 0.3.0) Скачайте дополнительные модели восстановления лиц (по ссылке ниже) и сохраните их в папку `extensions\sd-webui-comfyui\ComfyUI\models\facerestore_models`:<br>
|
229 |
+
https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models/facerestore_models
|
230 |
+
11. Запустите SD WebUI и проверьте консоль на сообщение, что ReActor Node работает:
|
231 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/uploads/console_status_running.jpg?raw=true" alt="console_status_running" width="759"/>
|
232 |
+
|
233 |
+
12. Перейдите во вкладку ComfyUI и найдите там ReActor Node внутри меню `ReActor` или через поиск:
|
234 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/uploads/webui-demo.png?raw=true" alt="webui-demo" width="100%"/>
|
235 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/uploads/search-demo.png?raw=true" alt="webui-demo" width="1043"/>
|
236 |
+
|
237 |
+
</details>
|
238 |
+
|
239 |
+
<details>
|
240 |
+
<summary>Портативная версия <a href="https://github.com/comfyanonymous/ComfyUI">ComfyUI</a> для Windows</summary>
|
241 |
+
|
242 |
+
1. Сделайте следующее:
|
243 |
+
- Установите [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/) (Например, версию Community - этот шаг нужен для правильной компиляции библиотеки Insightface)
|
244 |
+
- ИЛИ только [VS C++ Build Tools](https://visualstudio.microsoft.com/visual-cpp-build-tools/), выберите "Desktop Development with C++" в разделе "Workloads -> Desktop & Mobile"
|
245 |
+
- ИЛИ если же вы не хотите устанавливать что-либо из вышеуказанного - выполните [данные шаги (раздел. I)](#insightfacebuild)
|
246 |
+
2. Выберите из двух вариантов:
|
247 |
+
- (ComfyUI Manager) Откройте ComfyUI Manager, нажмите "Install Custom Nodes", введите "ReActor" в поле "Search" и далее нажмите "Install". После того, как ComfyUI завершит установку, перезагрузите сервер.
|
248 |
+
- (Вручную) Перейдите в `ComfyUI\custom_nodes`, откройте Консоль и выполните `git clone https://github.com/Gourieff/ComfyUI-ReActor`
|
249 |
+
3. Перейдите `ComfyUI\custom_nodes\ComfyUI-ReActor` и запустите `install.bat`, дождитесь окончания установки
|
250 |
+
4. Если модель "face_yolov8m.pt" у вас отсутствует - можете скачать её [отсюда](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/detection/bbox/face_yolov8m.pt) и положить в папку "ComfyUI\models\ultralytics\bbox"<br>
|
251 |
+
То же самое и с "Sams" моделями, скачайте одну или обе [отсюда](https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models/sams) - и положите в папку "ComfyUI\models\sams"
|
252 |
+
5. Запустите ComfyUI и найдите ReActor Node внутри меню `ReActor` или через поиск
|
253 |
+
|
254 |
+
</details>
|
255 |
+
|
256 |
+
<a name="usage">
|
257 |
+
|
258 |
+
## Использование
|
259 |
+
|
260 |
+
Вы можете найти ноды ReActor внутри меню `ReActor` или через поиск (достаточно ввести "ReActor" в поисковой строке)
|
261 |
+
|
262 |
+
Список нодов:
|
263 |
+
- ••• Main Nodes •••
|
264 |
+
- ReActorFaceSwap (Основной нод)
|
265 |
+
- ReActorFaceSwapOpt (Основной нод с доп. входом Options)
|
266 |
+
- ReActorOptions (Опции для ReActorFaceSwapOpt)
|
267 |
+
- ReActorFaceBoost (Нод Face Booster)
|
268 |
+
- ReActorMaskHelper (Masking Helper)
|
269 |
+
- ••• Operations with Face Models •••
|
270 |
+
- ReActorSaveFaceModel (Save Face Model)
|
271 |
+
- ReActorLoadFaceModel (Load Face Model)
|
272 |
+
- ReActorBuildFaceModel (Build Blended Face Model)
|
273 |
+
- ReActorMakeFaceModelBatch (Make Face Model Batch)
|
274 |
+
- ••• Additional Nodes •••
|
275 |
+
- ReActorRestoreFace (Face Restoration)
|
276 |
+
- ReActorImageDublicator (Dublicate one Image to Images List)
|
277 |
+
- ImageRGBA2RGB (Convert RGBA to RGB)
|
278 |
+
|
279 |
+
Соедините все необходимые слоты (slots) и запустите очередь (query).
|
280 |
+
|
281 |
+
### Входы основного Нода
|
282 |
+
|
283 |
+
- `input_image` - это изображение, на котором надо поменять лицо или лица (целевое изображение, аналог "target image" в версии для SD WebUI);
|
284 |
+
- Поддерживаемые ноды: "Load Image", "Load Video" или любые другие ноды предоставляющие изображение в качестве выхода;
|
285 |
+
- `source_image` - это изображение с лицом или лицами для замены (изображение-источник, аналог "source image" в версии для SD WebUI);
|
286 |
+
- Поддерживаемые ноды: "Load Image" или любые другие ноды с выходом Image(s);
|
287 |
+
- `face_model` - это вход для выхода с нода "Load Face Model" или другого нода ReActor для загрузки модели лица (face model или face embedding), которое вы создали ранее через нод "Save Face Model";
|
288 |
+
- Поддерживаемые ноды: "Load Face Model", "Build Blended Face Model";
|
289 |
+
|
290 |
+
### Выходы основного Нода
|
291 |
+
|
292 |
+
- `IMAGE` - выход с готовым изображением (результатом);
|
293 |
+
- Поддерживаемые ноды: любые ноды с изображением на входе;
|
294 |
+
- `FACE_MODEL` - выход, предоставляющий модель лица, построенную в ходе замены;
|
295 |
+
- Поддерживаемые ноды: "Save Face Model", "ReActor", "Make Face Model Batch";
|
296 |
+
|
297 |
+
### Восстановление лиц
|
298 |
+
|
299 |
+
Начиная с версии 0.3.0 ReActor Node имеет встроенное восстановление лиц.<br>Скачайте нужные вам модели (см. инструкцию по [Установке](#installation)) и выберите одну из них, чтобы улучшить качество финального лица.
|
300 |
+
|
301 |
+
### Индексы Лиц (Face Indexes)
|
302 |
+
|
303 |
+
По умолчанию ReActor определяет лица на изображении в порядке от "большого" к "малому".<br>Вы можете поменять эту опцию, используя нод ReActorFaceSwapOpt вместе с ReActorOptions.
|
304 |
+
|
305 |
+
Если вам нужно заменить определенное лицо, вы можете указать индекс для исходного (source, с лицом) и входного (input, где будет замена лица) изображений.
|
306 |
+
|
307 |
+
Индекс первого обнаруженного лица - 0.
|
308 |
+
|
309 |
+
Вы можете задать индексы в том порядке, который вам нужен.<br>
|
310 |
+
Например: 0,1,2 (для Source); 1,0,2 (для Input).<br>Это означает, что: второе лицо из Input (индекс = 1) будет заменено первым лицом из Source (индекс = 0) и так далее.
|
311 |
+
|
312 |
+
### Определение Пола
|
313 |
+
|
314 |
+
Вы можете обозначить, какой пол нужно определять на изображении.<br>
|
315 |
+
ReActor заменит только то лицо, которое удовлетворяет заданному условию.
|
316 |
+
|
317 |
+
### Модели Лиц
|
318 |
+
Начиная с версии 0.4.0, вы можете сохранять модели лиц как файлы "safetensors" (хранятся в папке `ComfyUI\models\reactor\faces`) и загружать их в ReActor, реализуя разные сценарии использования, а также храня супер легкие модели лиц, которые вы чаще всего используете.
|
319 |
+
|
320 |
+
Чтобы новые модели появились в списке моделей нода "Load Face Model" - обновите страницу с ComfyUI.<br>
|
321 |
+
(Рекомендую использовать ComfyUI Manager - иначе ваше воркфлоу может быть потеряно после перезагрузки страницы, если вы не сохранили его).
|
322 |
+
|
323 |
+
<a name="troubleshooting">
|
324 |
+
|
325 |
+
## Устранение проблем
|
326 |
+
|
327 |
+
<a name="insightfacebuild">
|
328 |
+
|
329 |
+
### **I. (Для пользователей Windows) Если вы до сих пор не можете установить пакет Insightface по каким-то причинам или же просто не желаете устанавливать Visual Studio или VS C++ Build Tools - сделайте следующее:**
|
330 |
+
|
331 |
+
1. (ComfyUI Portable) Находясь в корневой директории, проверьте версию Python:<br>запустите CMD и выполните `python_embeded\python.exe -V`<br>Вы должны увидеть версию или 3.10, или 3.11, или 3.12
|
332 |
+
2. Скачайте готовый пакет Insightface [для версии 3.10](https://github.com/Gourieff/sd-webui-reactor/raw/main/example/insightface-0.7.3-cp310-cp310-win_amd64.whl) или [для 3.11](https://github.com/Gourieff/Assets/raw/main/Insightface/insightface-0.7.3-cp311-cp311-win_amd64.whl) (если на предыдущем шаге вы увидели 3.11) или [для 3.12](https://github.com/Gourieff/Assets/raw/main/Insightface/insightface-0.7.3-cp312-cp312-win_amd64.whl) (если на предыдущем шаге вы увидели 3.12) и сохраните его в корневую директорию stable-diffusion-webui (A1111 или SD.Next) - туда, где лежит файл "webui-user.bat" -ИЛИ- в корневую директорию ComfyUI, если вы используете ComfyUI Portable
|
333 |
+
3. Из корневой директории запустите:
|
334 |
+
- (SD WebUI) CMD и `.\venv\Scripts\activate`
|
335 |
+
- (ComfyUI Portable) CMD
|
336 |
+
4. Обновите PIP:
|
337 |
+
- (SD WebUI) `python -m pip install -U pip`
|
338 |
+
- (ComfyUI Portable) `python_embeded\python.exe -m pip install -U pip`
|
339 |
+
5. Затем установите Insightface:
|
340 |
+
- (SD WebUI) `pip install insightface-0.7.3-cp310-cp310-win_amd64.whl` (для 3.10) или `pip install insightface-0.7.3-cp311-cp311-win_amd64.whl` (для 3.11) или `pip install insightface-0.7.3-cp312-cp312-win_amd64.whl` (для 3.12)
|
341 |
+
- (ComfyUI Portable) `python_embeded\python.exe -m pip install insightface-0.7.3-cp310-cp310-win_amd64.whl` (для 3.10) или `python_embeded\python.exe -m pip install insightface-0.7.3-cp311-cp311-win_amd64.whl` (для 3.11) или `python_embeded\python.exe -m pip install insightface-0.7.3-cp312-cp312-win_amd64.whl` (для 3.12)
|
342 |
+
6. Готово!
|
343 |
+
|
344 |
+
### **II. "AttributeError: 'NoneType' object has no attribute 'get'"**
|
345 |
+
|
346 |
+
Эта ошибка появляется, если что-то не так с файлом модели `inswapper_128.onnx`
|
347 |
+
|
348 |
+
Скачайте вручную по ссылке [отсюда](https://github.com/facefusion/facefusion-assets/releases/download/models/inswapper_128.onnx)
|
349 |
+
и сохраните в директорию `ComfyUI\models\insightface`, заменив имеющийся файл
|
350 |
+
|
351 |
+
### **III. "reactor.execute() got an unexpected keyword argument 'reference_image'"**
|
352 |
+
|
353 |
+
Это означает, что поменялось обозначение входных точек (input points) в связи с последним обновлением<br>
|
354 |
+
Удалите из вашего рабочего пространства имеющийся ReActor Node и добавьте его снова
|
355 |
+
|
356 |
+
### **IV. ControlNet Aux Node IMPORT failed - при использовании совместно с нодом ReActor**
|
357 |
+
|
358 |
+
1. Закройте или остановите ComfyUI сервер, если он запущен
|
359 |
+
2. Перейдите в корневую папку ComfyUI, откройте консоль CMD и выполните следующее:
|
360 |
+
- `python_embeded\python.exe -m pip uninstall -y opencv-python opencv-contrib-python opencv-python-headless`
|
361 |
+
- `python_embeded\python.exe -m pip install opencv-python==4.7.0.72`
|
362 |
+
3. Готово!
|
363 |
+
|
364 |
+
<img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/uploads/reactor-w-controlnet.png?raw=true" alt="reactor+controlnet" />
|
365 |
+
|
366 |
+
### **V. "ModuleNotFoundError: No module named 'basicsr'" или "subprocess-exited-with-error" при установке пакета future-0.18.3**
|
367 |
+
|
368 |
+
- Скачайте https://github.com/Gourieff/Assets/raw/main/comfyui-reactor-node/future-0.18.3-py3-none-any.whl<br>
|
369 |
+
- Скопируйте файл в корневую папку ComfyUI и выполните в консоли:
|
370 |
+
|
371 |
+
python_embeded\python.exe -m pip install future-0.18.3-py3-none-any.whl
|
372 |
+
|
373 |
+
- Затем:
|
374 |
+
|
375 |
+
python_embeded\python.exe -m pip install basicsr
|
376 |
+
|
377 |
+
### **VI. "fatal: fetch-pack: invalid index-pack output" при исполнении команды `git clone`"**
|
378 |
+
|
379 |
+
Попробуйте клонировать репозиторий с параметром `--depth=1` (только последний коммит):
|
380 |
+
|
381 |
+
git clone --depth=1 https://github.com/Gourieff/ComfyUI-ReActor
|
382 |
+
|
383 |
+
Затем вытяните оставшееся (если требуется):
|
384 |
+
|
385 |
+
git fetch --unshallow
|
386 |
+
|
387 |
+
<a name="updating">
|
388 |
+
|
389 |
+
## Обновление
|
390 |
+
|
391 |
+
Положите .bat или .sh скрипт из [данного репозитория](https://github.com/Gourieff/sd-webui-extensions-updater) в папку `ComfyUI\custom_nodes` и запустите, когда желаете обновить ComfyUI и Ноды
|
392 |
+
|
393 |
+
<a name="disclaimer">
|
394 |
+
|
395 |
+
## Ответственность
|
396 |
+
|
397 |
+
Это программное обеспечение призвано стать продуктивным вкладом в быстрорастущую медиаиндустрию на основе генеративных сетей и искусственного интеллекта. Данное ПО поможет художникам в решении таких задач, как анимация собственного персонажа или использование персонажа в качестве модели для одежды и т.д.
|
398 |
+
|
399 |
+
Разработчики этого программного обеспечения осведомлены о возможных неэтичных применениях и обязуются принять против этого превентивные меры. Мы продолжим развивать этот проект в позитивном направлении, придерживаясь закона и этики.
|
400 |
+
|
401 |
+
Подразумевается, что пользователи этого программного обеспечения будут использовать его ответственно, соблюдая локальное законодательство. Если используется лицо реального человека, пользователь обязан получить согласие заинтересованного лица и четко указать, что это дипфейк при размещении контента в Интернете. **Разработчики и Со-авторы данного программного обеспечения не несут ответственности за действия конечных пользователей.**
|
402 |
+
|
403 |
+
Используя данное расширение, вы соглашаетесь не создавать материалы, которые:
|
404 |
+
- нарушают какие-либо действующие законы тех или иных государств или международных организаций;
|
405 |
+
- причиняют какой-либо вред человеку или лицам;
|
406 |
+
- пропагандируют любую информацию (как общедоступную, так и личную) или изображения (как общедоступные, так и личные), которые могут быть направлены на причинение вреда;
|
407 |
+
- используются для распространения дезинформации;
|
408 |
+
- нацелены на уязвимые группы людей.
|
409 |
+
|
410 |
+
Данное программное обеспечение использует предварительно обученные модели `buffalo_l` и `inswapper_128.onnx`, представленные разработчиками [InsightFace](https://github.com/deepinsight/insightface/). Эти модели распространяются при следующих условиях:
|
411 |
+
|
412 |
+
[Перевод из текста лицензии insightface](https://github.com/deepinsight/insightface/tree/master/python-package): Предварительно обученные модели InsightFace доступны только для некоммерческих исследовательских целей. Сюда входят как модели с автоматической загрузкой, так и модели, загруженные вручную.
|
413 |
+
|
414 |
+
Пользователи данного программного обеспечения должны строго соблюдать данные условия использования. Разработчики и Со-авторы данного программного продукта не несут ответственности за неправильное использование предварительно обученных моделей InsightFace.
|
415 |
+
|
416 |
+
Обратите внимание: если вы собираетесь использовать это программное обеспечение в каких-либо коммерческих целях, вам необходимо будет обучить свои собственные модели или найти модели, которые можно использовать в коммерческих целях.

### Models Hash

#### Safe-to-use models have the following hash:

inswapper_128.onnx

```
MD5:a3a155b90354160350efd66fed6b3d80
SHA256:e4a3f08c753cb72d04e10aa0f7dbe3deebbf39567d4ead6dce08e98aa49e16af
```

1k3d68.onnx

```
MD5:6fb94fcdb0055e3638bf9158e6a108f4
SHA256:df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc
```

2d106det.onnx

```
MD5:a3613ef9eb3662b4ef88eb90db1fcf26
SHA256:f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf
```

det_10g.onnx

```
MD5:4c10eef5c9e168357a16fdd580fa8371
SHA256:5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91
```

genderage.onnx

```
MD5:81c77ba87ab38163b0dec6b26f8e2af2
SHA256:4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb
```

w600k_r50.onnx

```
MD5:80248d427976241cbd1343889ed132b3
SHA256:4c06341c33c2ca1f86781dab0e829f88ad5b64be9fba56e56bc9ebdefc619e43
```

**Please compare the hashes if you download these models from unverified sources**
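
For example, you can compute both digests locally with Python's standard `hashlib` (the file name below is one of the models listed above; adjust the path to wherever you stored it):

```python
import hashlib

def file_digests(path, chunk_size=1 << 20):
    """Stream the file once and return its (md5, sha256) hex digests."""
    md5, sha256 = hashlib.md5(), hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            md5.update(chunk)
            sha256.update(chunk)
    return md5.hexdigest(), sha256.hexdigest()

md5_hex, sha256_hex = file_digests("inswapper_128.onnx")
print("MD5:", md5_hex)
print("SHA256:", sha256_hex)
```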

<a name="credits">

## Credits and component authors

<details>
<summary><a>Click to view</a></summary>

<br>

|file|source|license|
|----|--------|--------|
| [buffalo_l.zip](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/buffalo_l.zip) | [DeepInsight](https://github.com/deepinsight/insightface) |  |
| [codeformer-v0.1.0.pth](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/facerestore_models/codeformer-v0.1.0.pth) | [sczhou](https://github.com/sczhou/CodeFormer) |  |
| [GFPGANv1.3.pth](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/facerestore_models/GFPGANv1.3.pth) | [TencentARC](https://github.com/TencentARC/GFPGAN) |  |
| [GFPGANv1.4.pth](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/facerestore_models/GFPGANv1.4.pth) | [TencentARC](https://github.com/TencentARC/GFPGAN) |  |
| [inswapper_128.onnx](https://github.com/facefusion/facefusion-assets/releases/download/models/inswapper_128.onnx) | [DeepInsight](https://github.com/deepinsight/insightface) |  |
| [inswapper_128_fp16.onnx](https://github.com/facefusion/facefusion-assets/releases/download/models/inswapper_128_fp16.onnx) | [Hillobar](https://github.com/Hillobar/Rope) |  |

[BasicSR](https://github.com/XPixelGroup/BasicSR) - [@XPixelGroup](https://github.com/XPixelGroup) <br>
[facexlib](https://github.com/xinntao/facexlib) - [@xinntao](https://github.com/xinntao) <br>

[@s0md3v](https://github.com/s0md3v), [@henryruhs](https://github.com/henryruhs) - the original Roop App <br>
[@ssitu](https://github.com/ssitu) - the first version of the extension with ComfyUI support, [ComfyUI_roop](https://github.com/ssitu/ComfyUI_roop)

</details>

<a name="note">

### Note!

**If you encounter any errors the next time you use the ReActor Node - don't rush to open an issue; first try removing the current Node from your workspace and adding it back again**

**The ReActor Node gets updates from time to time and new features appear, which may cause an existing Node to work with errors or not work at all**

custom_nodes/ComfyUI-ReActor/__init__.py
ADDED
@@ -0,0 +1,39 @@
import sys
import os

repo_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(0, repo_dir)
original_modules = sys.modules.copy()

# Place aside existing modules if using a1111 web ui
modules_used = [
    "modules",
    "modules.images",
    "modules.processing",
    "modules.scripts_postprocessing",
    "modules.scripts",
    "modules.shared",
]
original_webui_modules = {}
for module in modules_used:
    if module in sys.modules:
        original_webui_modules[module] = sys.modules.pop(module)

# Proceed with node setup
from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]

# Clean up imports
# Remove repo directory from path
sys.path.remove(repo_dir)
# Remove any new modules
modules_to_remove = []
for module in sys.modules:
    if module not in original_modules and not module.startswith("google.protobuf") and not module.startswith("onnx") and not module.startswith("cv2"):
        modules_to_remove.append(module)
for module in modules_to_remove:
    del sys.modules[module]

# Restore original modules
sys.modules.update(original_webui_modules)
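
A note on the module juggling above: ReActor ships stub `modules/*` files that mimic the AUTOMATIC1111 web UI API, and this `__init__.py` temporarily pops any already-loaded `modules` package out of `sys.modules` so the stubs win during import, then restores the original afterwards. A minimal sketch of that save-and-restore pattern in isolation (the module name here is illustrative):

```python
import sys
import types

# Pretend a host app has already imported a package named "modules"
host_modules = types.ModuleType("modules")
sys.modules["modules"] = host_modules

saved = sys.modules.pop("modules")                    # 1. set the host's package aside
sys.modules["modules"] = types.ModuleType("modules")  # 2. install our stand-in
# ... run imports that should see the stand-in here ...
sys.modules["modules"] = saved                        # 3. put the host's package back

assert sys.modules["modules"] is host_modules
```
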
custom_nodes/ComfyUI-ReActor/install.bat
ADDED
@@ -0,0 +1,37 @@
@echo off
setlocal enabledelayedexpansion

:: Try to use embedded python first
if exist ..\..\..\python_embeded\python.exe (
    :: Use the embedded python
    set PYTHON=..\..\..\python_embeded\python.exe
) else (
    :: Embedded python not found, check for python in the PATH
    for /f "tokens=* USEBACKQ" %%F in (`python --version 2^>^&1`) do (
        set PYTHON_VERSION=%%F
    )
    if errorlevel 1 (
        echo I couldn't find an embedded version of Python, nor one in the Windows PATH. Please install manually.
        pause
        exit /b 1
    ) else (
        :: Use python from the PATH (if it's the right version and the user agrees)
        echo I couldn't find an embedded version of Python, but I did find !PYTHON_VERSION! in your Windows PATH.
        echo Would you like to proceed with the install using that version? (Y/N^)
        set /p USE_PYTHON=
        if /i "!USE_PYTHON!"=="Y" (
            set PYTHON=python
        ) else (
            echo Okay. Please install manually.
            pause
            exit /b 1
        )
    )
)

:: Install the package
echo Installing...
%PYTHON% install.py
echo Done^!

@pause
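
The probe order the script uses (embedded interpreter first, then whatever `python` is on PATH) can also be expressed in Python, e.g. if you script the install yourself; this is a sketch, and the relative path is the portable-ComfyUI layout the .bat assumes:

```python
import os
import shutil

def find_python(base=os.path.join("..", "..", "..")):
    """Prefer ComfyUI's embedded interpreter, else fall back to PATH."""
    embedded = os.path.join(base, "python_embeded", "python.exe")
    if os.path.isfile(embedded):
        return embedded
    return shutil.which("python")  # None if nothing suitable is on PATH

print(find_python())
```
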
custom_nodes/ComfyUI-ReActor/install.py
ADDED
@@ -0,0 +1,104 @@
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import subprocess
import os, sys
try:
    from pkg_resources import get_distribution as distributions
except:
    # stdlib fallback: look up a single distribution by name
    from importlib.metadata import distribution as distributions
from tqdm import tqdm
import urllib.request
from packaging import version as pv
try:
    from folder_paths import models_dir
except:
    from pathlib import Path
    models_dir = os.path.join(Path(__file__).parents[2], "models")

sys.path.append(os.path.dirname(os.path.realpath(__file__)))

req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt")

model_url = "https://huggingface.co/datasets/Gourieff/ReActor/resolve/main/models/inswapper_128.onnx"
model_name = os.path.basename(model_url)
models_dir_path = os.path.join(models_dir, "insightface")
model_path = os.path.join(models_dir_path, model_name)

def run_pip(*args):
    subprocess.run([sys.executable, "-m", "pip", "install", "--no-warn-script-location", *args])

def is_installed(
    package: str, version: str = None, strict: bool = True
):
    has_package = None
    try:
        has_package = distributions(package)
        if has_package is not None:
            if version is not None:
                installed_version = has_package.version
                if (installed_version != version and strict == True) or (pv.parse(installed_version) < pv.parse(version) and strict == False):
                    return False
                else:
                    return True
            else:
                return True
        else:
            return False
    except Exception as e:
        print(f"Status: {e}")
        return False

def download(url, path, name):
    request = urllib.request.urlopen(url)
    total = int(request.headers.get('Content-Length', 0))
    with tqdm(total=total, desc=f'[ReActor] Downloading {name} to {path}', unit='B', unit_scale=True, unit_divisor=1024) as progress:
        urllib.request.urlretrieve(url, path, reporthook=lambda count, block_size, total_size: progress.update(block_size))

if not os.path.exists(models_dir_path):
    os.makedirs(models_dir_path)

if not os.path.exists(model_path):
    download(model_url, model_path, model_name)

with open(req_file) as file:
    try:
        ort = "onnxruntime-gpu"
        import torch
        cuda_version = None
        if torch.cuda.is_available():
            cuda_version = torch.version.cuda
            print(f"CUDA {cuda_version}")
        elif torch.backends.mps.is_available() or hasattr(torch,'dml') or hasattr(torch,'privateuseone'):
            ort = "onnxruntime"
        if cuda_version is not None and float(cuda_version)>=12 and torch.torch_version.__version__ <= "2.2.0": # CU12.x and torch<=2.2.0
            print(f"Torch: {torch.torch_version.__version__}")
            if not is_installed(ort,"1.17.0",False):
                run_pip(ort,"--extra-index-url", "https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/")
        elif cuda_version is not None and float(cuda_version)>=12 and torch.torch_version.__version__ >= "2.4.0" : # CU12.x and latest torch
            print(f"Torch: {torch.torch_version.__version__}")
            if not is_installed(ort,"1.20.1",False): # latest ort-gpu
                run_pip(ort,"-U")
        elif not is_installed(ort,"1.16.1",False):
            run_pip(ort, "-U")
    except Exception as e:
        print(e)
        print(f"Warning: Failed to install {ort}, ReActor will not work.")
        raise e
    strict = True
    for package in file:
        package_version = None
        try:
            package = package.strip()
            if "==" in package:
                package_version = package.split('==')[1]
            elif ">=" in package:
                package_version = package.split('>=')[1]
                strict = False
            if not is_installed(package,package_version,strict):
                run_pip(package)
        except Exception as e:
            print(e)
            print(f"Warning: Failed to install {package}, ReActor will not work.")
            raise e
print("Ok")
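
The `is_installed` helper above treats `==` pins as exact matches and `>=` pins as minimum versions (`strict` is flipped to `False` when a `>=` pin is seen). Here is the same comparison rule on its own, using `packaging` directly; the version strings are made up for illustration:

```python
from packaging import version as pv

def satisfies(installed: str, pinned: str, strict: bool) -> bool:
    """strict=True mirrors an '==' pin, strict=False a '>=' pin."""
    if strict:
        return installed == pinned
    return pv.parse(installed) >= pv.parse(pinned)

print(satisfies("1.17.3", "1.16.1", strict=False))  # True: 1.17.3 >= 1.16.1
print(satisfies("1.17.3", "1.16.1", strict=True))   # False: not an exact match
```
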
custom_nodes/ComfyUI-ReActor/modules/__init__.py
ADDED
File without changes
custom_nodes/ComfyUI-ReActor/modules/images.py
ADDED
File without changes
custom_nodes/ComfyUI-ReActor/modules/processing.py
ADDED
@@ -0,0 +1,13 @@
class StableDiffusionProcessing:

    def __init__(self, init_imgs):
        self.init_images = init_imgs
        self.width = init_imgs[0].width
        self.height = init_imgs[0].height
        self.extra_generation_params = {}


class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):

    def __init__(self, init_img):
        super().__init__(init_img)
ADDED
@@ -0,0 +1,13 @@
|
import os


class Script:
    pass


def basedir():
    return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


class PostprocessImageArgs:
    pass
custom_nodes/ComfyUI-ReActor/modules/scripts_postprocessing.py
ADDED
File without changes
custom_nodes/ComfyUI-ReActor/modules/shared.py
ADDED
@@ -0,0 +1,19 @@
class Options:
    img2img_background_color = "#ffffff"  # Set to white for now


class State:
    interrupted = False

    def begin(self):
        pass

    def end(self):
        pass


opts = Options()
state = State()
cmd_opts = None
sd_upscalers = []
face_restorers = []
custom_nodes/ComfyUI-ReActor/nodes.py
ADDED
@@ -0,0 +1,1364 @@
import os, glob, sys
import logging

import torch
import torch.nn.functional as torchfn
from torchvision.transforms.functional import normalize
from torchvision.ops import masks_to_boxes

import numpy as np
import cv2
import math
from typing import List
from PIL import Image
from scipy import stats
from insightface.app.common import Face
from segment_anything import sam_model_registry

from modules.processing import StableDiffusionProcessingImg2Img
from modules.shared import state
# from comfy_extras.chainner_models import model_loading
import comfy.model_management as model_management
import comfy.utils
import folder_paths

import scripts.reactor_version
from r_chainner import model_loading
from scripts.reactor_faceswap import (
    FaceSwapScript,
    get_models,
    get_current_faces_model,
    analyze_faces,
    half_det_size,
    providers
)
from scripts.reactor_swapper import (
    unload_all_models,
)
from scripts.reactor_logger import logger
from reactor_utils import (
    batch_tensor_to_pil,
    batched_pil_to_tensor,
    tensor_to_pil,
    img2tensor,
    tensor2img,
    save_face_model,
    load_face_model,
    download,
    set_ort_session,
    prepare_cropped_face,
    normalize_cropped_face,
    add_folder_path_and_extensions,
    rgba2rgb_tensor
)
from reactor_patcher import apply_patch
from r_facelib.utils.face_restoration_helper import FaceRestoreHelper
from r_basicsr.utils.registry import ARCH_REGISTRY
import scripts.r_archs.codeformer_arch
import scripts.r_masking.subcore as subcore
import scripts.r_masking.core as core
import scripts.r_masking.segs as masking_segs

import scripts.reactor_sfw as sfw


models_dir = folder_paths.models_dir
REACTOR_MODELS_PATH = os.path.join(models_dir, "reactor")
FACE_MODELS_PATH = os.path.join(REACTOR_MODELS_PATH, "faces")
NSFWDET_MODEL_PATH = os.path.join(models_dir, "nsfw_detector","vit-base-nsfw-detector")

if not os.path.exists(REACTOR_MODELS_PATH):
    os.makedirs(REACTOR_MODELS_PATH)
if not os.path.exists(FACE_MODELS_PATH):
    os.makedirs(FACE_MODELS_PATH)

dir_facerestore_models = os.path.join(models_dir, "facerestore_models")
os.makedirs(dir_facerestore_models, exist_ok=True)
folder_paths.folder_names_and_paths["facerestore_models"] = ([dir_facerestore_models], folder_paths.supported_pt_extensions)

BLENDED_FACE_MODEL = None
FACE_SIZE: int = 512
FACE_HELPER = None

if "ultralytics" not in folder_paths.folder_names_and_paths:
    add_folder_path_and_extensions("ultralytics_bbox", [os.path.join(models_dir, "ultralytics", "bbox")], folder_paths.supported_pt_extensions)
    add_folder_path_and_extensions("ultralytics_segm", [os.path.join(models_dir, "ultralytics", "segm")], folder_paths.supported_pt_extensions)
    add_folder_path_and_extensions("ultralytics", [os.path.join(models_dir, "ultralytics")], folder_paths.supported_pt_extensions)
if "sams" not in folder_paths.folder_names_and_paths:
    add_folder_path_and_extensions("sams", [os.path.join(models_dir, "sams")], folder_paths.supported_pt_extensions)

def get_facemodels():
    models_path = os.path.join(FACE_MODELS_PATH, "*")
    models = glob.glob(models_path)
    models = [x for x in models if x.endswith(".safetensors")]
    return models

def get_restorers():
    models_path = os.path.join(models_dir, "facerestore_models/*")
    models = glob.glob(models_path)
    models = [x for x in models if (x.endswith(".pth") or x.endswith(".onnx"))]
    if len(models) == 0:
        fr_urls = [
            "https://huggingface.co/datasets/Gourieff/ReActor/resolve/main/models/facerestore_models/GFPGANv1.3.pth",
            "https://huggingface.co/datasets/Gourieff/ReActor/resolve/main/models/facerestore_models/GFPGANv1.4.pth",
            "https://huggingface.co/datasets/Gourieff/ReActor/resolve/main/models/facerestore_models/codeformer-v0.1.0.pth",
            "https://huggingface.co/datasets/Gourieff/ReActor/resolve/main/models/facerestore_models/GPEN-BFR-512.onnx",
        ]
        for model_url in fr_urls:
            model_name = os.path.basename(model_url)
            model_path = os.path.join(dir_facerestore_models, model_name)
            download(model_url, model_path, model_name)
        models = glob.glob(models_path)
        models = [x for x in models if (x.endswith(".pth") or x.endswith(".onnx"))]
    return models

def get_model_names(get_models):
    models = get_models()
    names = []
    for x in models:
        names.append(os.path.basename(x))
    names.sort(key=str.lower)
    names.insert(0, "none")
    return names

def model_names():
    models = get_models()
    return {os.path.basename(x): x for x in models}


class reactor:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "enabled": ("BOOLEAN", {"default": True, "label_off": "OFF", "label_on": "ON"}),
                "input_image": ("IMAGE",),
                "swap_model": (list(model_names().keys()),),
                "facedetection": (["retinaface_resnet50", "retinaface_mobile0.25", "YOLOv5l", "YOLOv5n"],),
                "face_restore_model": (get_model_names(get_restorers),),
                "face_restore_visibility": ("FLOAT", {"default": 1, "min": 0.1, "max": 1, "step": 0.05}),
                "codeformer_weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1, "step": 0.05}),
                "detect_gender_input": (["no","female","male"], {"default": "no"}),
                "detect_gender_source": (["no","female","male"], {"default": "no"}),
                "input_faces_index": ("STRING", {"default": "0"}),
                "source_faces_index": ("STRING", {"default": "0"}),
                "console_log_level": ([0, 1, 2], {"default": 1}),
            },
            "optional": {
                "source_image": ("IMAGE",),
                "face_model": ("FACE_MODEL",),
                "face_boost": ("FACE_BOOST",),
            },
            "hidden": {"faces_order": "FACES_ORDER"},
        }

    RETURN_TYPES = ("IMAGE","FACE_MODEL")
    FUNCTION = "execute"
    CATEGORY = "🌌 ReActor"

    def __init__(self):
        # self.face_helper = None
        self.faces_order = ["large-small", "large-small"]
        # self.face_size = FACE_SIZE
        self.face_boost_enabled = False
        self.restore = True
        self.boost_model = None
        self.interpolation = "Bicubic"
        self.boost_model_visibility = 1
        self.boost_cf_weight = 0.5

    def restore_face(
        self,
        input_image,
        face_restore_model,
        face_restore_visibility,
        codeformer_weight,
        facedetection,
    ):

        result = input_image

        if face_restore_model != "none" and not model_management.processing_interrupted():

            global FACE_SIZE, FACE_HELPER

            self.face_helper = FACE_HELPER

            faceSize = 512
            if "1024" in face_restore_model.lower():
                faceSize = 1024
            elif "2048" in face_restore_model.lower():
                faceSize = 2048

            logger.status(f"Restoring with {face_restore_model} | Face Size is set to {faceSize}")

            model_path = folder_paths.get_full_path("facerestore_models", face_restore_model)

            device = model_management.get_torch_device()

            if "codeformer" in face_restore_model.lower():

                codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
                    dim_embd=512,
                    codebook_size=1024,
                    n_head=8,
                    n_layers=9,
                    connect_list=["32", "64", "128", "256"],
                ).to(device)
                checkpoint = torch.load(model_path)["params_ema"]
                codeformer_net.load_state_dict(checkpoint)
                facerestore_model = codeformer_net.eval()

            elif ".onnx" in face_restore_model:

                ort_session = set_ort_session(model_path, providers=providers)
                ort_session_inputs = {}
                facerestore_model = ort_session

            else:

                sd = comfy.utils.load_torch_file(model_path, safe_load=True)
                facerestore_model = model_loading.load_state_dict(sd).eval()
                facerestore_model.to(device)

            if faceSize != FACE_SIZE or self.face_helper is None:
                self.face_helper = FaceRestoreHelper(1, face_size=faceSize, crop_ratio=(1, 1), det_model=facedetection, save_ext='png', use_parse=True, device=device)
                FACE_SIZE = faceSize
                FACE_HELPER = self.face_helper

            image_np = 255. * result.numpy()

            total_images = image_np.shape[0]

            out_images = []

            for i in range(total_images):

                if total_images > 1:
                    logger.status(f"Restoring {i+1}")

                cur_image_np = image_np[i,:, :, ::-1]

                original_resolution = cur_image_np.shape[0:2]

                if facerestore_model is None or self.face_helper is None:
                    return result

                self.face_helper.clean_all()
                self.face_helper.read_image(cur_image_np)
                self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
                self.face_helper.align_warp_face()

                restored_face = None

                for idx, cropped_face in enumerate(self.face_helper.cropped_faces):

                    # if ".pth" in face_restore_model:
                    cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
                    normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
                    cropped_face_t = cropped_face_t.unsqueeze(0).to(device)

                    try:

                        with torch.no_grad():

                            if ".onnx" in face_restore_model: # ONNX models

                                for ort_session_input in ort_session.get_inputs():
                                    if ort_session_input.name == "input":
                                        cropped_face_prep = prepare_cropped_face(cropped_face)
                                        ort_session_inputs[ort_session_input.name] = cropped_face_prep
                                    if ort_session_input.name == "weight":
                                        weight = np.array([ 1 ], dtype = np.double)
                                        ort_session_inputs[ort_session_input.name] = weight

                                output = ort_session.run(None, ort_session_inputs)[0][0]
                                restored_face = normalize_cropped_face(output)

                            else: # PTH models

                                output = facerestore_model(cropped_face_t, w=codeformer_weight)[0] if "codeformer" in face_restore_model.lower() else facerestore_model(cropped_face_t)[0]
                                restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))

                        del output
                        torch.cuda.empty_cache()

                    except Exception as error:

                        print(f"\tFailed inference: {error}", file=sys.stderr)
                        restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))

                    if face_restore_visibility < 1:
                        restored_face = cropped_face * (1 - face_restore_visibility) + restored_face * face_restore_visibility

                    restored_face = restored_face.astype("uint8")
                    self.face_helper.add_restored_face(restored_face)

                self.face_helper.get_inverse_affine(None)

                restored_img = self.face_helper.paste_faces_to_input_image()
                restored_img = restored_img[:, :, ::-1]

                if original_resolution != restored_img.shape[0:2]:
                    restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_AREA)

                self.face_helper.clean_all()

                # out_images[i] = restored_img
                out_images.append(restored_img)

                if state.interrupted or model_management.processing_interrupted():
                    logger.status("Interrupted by User")
                    return input_image

            restored_img_np = np.array(out_images).astype(np.float32) / 255.0
            restored_img_tensor = torch.from_numpy(restored_img_np)

            result = restored_img_tensor

        return result

    def execute(self, enabled, input_image, swap_model, detect_gender_source, detect_gender_input, source_faces_index, input_faces_index, console_log_level, face_restore_model,face_restore_visibility, codeformer_weight, facedetection, source_image=None, face_model=None, faces_order=None, face_boost=None):

        if face_boost is not None:
            self.face_boost_enabled = face_boost["enabled"]
            self.boost_model = face_boost["boost_model"]
            self.interpolation = face_boost["interpolation"]
            self.boost_model_visibility = face_boost["visibility"]
            self.boost_cf_weight = face_boost["codeformer_weight"]
            self.restore = face_boost["restore_with_main_after"]
        else:
            self.face_boost_enabled = False

        if faces_order is None:
            faces_order = self.faces_order

        apply_patch(console_log_level)

        if not enabled:
            return (input_image,face_model)
        elif source_image is None and face_model is None:
            logger.error("Please provide 'source_image' or `face_model`")
            return (input_image,face_model)

        if face_model == "none":
            face_model = None

        script = FaceSwapScript()
        pil_images = batch_tensor_to_pil(input_image)

        # NSFW checker
        logger.status("Checking for any unsafe content")
        pil_images_sfw = []
        tmp_img = "reactor_tmp.png"
        for img in pil_images:
            if state.interrupted or model_management.processing_interrupted():
                logger.status("Interrupted by User")
                break
            img.save(tmp_img)
            if not sfw.nsfw_image(tmp_img, NSFWDET_MODEL_PATH):
                pil_images_sfw.append(img)
        if os.path.exists(tmp_img):
            os.remove(tmp_img)
        pil_images = pil_images_sfw
        # # #

        if len(pil_images) > 0:

            if source_image is not None:
                source = tensor_to_pil(source_image)
            else:
                source = None
            p = StableDiffusionProcessingImg2Img(pil_images)
            script.process(
                p=p,
                img=source,
                enable=True,
                source_faces_index=source_faces_index,
                faces_index=input_faces_index,
                model=swap_model,
                swap_in_source=True,
                swap_in_generated=True,
                gender_source=detect_gender_source,
                gender_target=detect_gender_input,
                face_model=face_model,
                faces_order=faces_order,
                # face boost:
                face_boost_enabled=self.face_boost_enabled,
                face_restore_model=self.boost_model,
                face_restore_visibility=self.boost_model_visibility,
                codeformer_weight=self.boost_cf_weight,
                interpolation=self.interpolation,
            )
            result = batched_pil_to_tensor(p.init_images)

            if face_model is None:
                current_face_model = get_current_faces_model()
                face_model_to_provide = current_face_model[0] if (current_face_model is not None and len(current_face_model) > 0) else face_model
            else:
                face_model_to_provide = face_model

            if self.restore or not self.face_boost_enabled:
                result = reactor.restore_face(self,result,face_restore_model,face_restore_visibility,codeformer_weight,facedetection)

        else:
            image_black = Image.new("RGB", (512, 512))
            result = batched_pil_to_tensor([image_black])
            face_model_to_provide = None

        return (result,face_model_to_provide)


class ReActorPlusOpt:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "enabled": ("BOOLEAN", {"default": True, "label_off": "OFF", "label_on": "ON"}),
                "input_image": ("IMAGE",),
                "swap_model": (list(model_names().keys()),),
                "facedetection": (["retinaface_resnet50", "retinaface_mobile0.25", "YOLOv5l", "YOLOv5n"],),
                "face_restore_model": (get_model_names(get_restorers),),
                "face_restore_visibility": ("FLOAT", {"default": 1, "min": 0.1, "max": 1, "step": 0.05}),
                "codeformer_weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1, "step": 0.05}),
            },
            "optional": {
                "source_image": ("IMAGE",),
                "face_model": ("FACE_MODEL",),
                "options": ("OPTIONS",),
                "face_boost": ("FACE_BOOST",),
            }
        }

    RETURN_TYPES = ("IMAGE","FACE_MODEL")
    FUNCTION = "execute"
    CATEGORY = "🌌 ReActor"

    def __init__(self):
        # self.face_helper = None
        self.faces_order = ["large-small", "large-small"]
        self.detect_gender_input = "no"
        self.detect_gender_source = "no"
        self.input_faces_index = "0"
        self.source_faces_index = "0"
        self.console_log_level = 1
        # self.face_size = 512
        self.face_boost_enabled = False
        self.restore = True
        self.boost_model = None
        self.interpolation = "Bicubic"
        self.boost_model_visibility = 1
        self.boost_cf_weight = 0.5

    def execute(self, enabled, input_image, swap_model, facedetection, face_restore_model, face_restore_visibility, codeformer_weight, source_image=None, face_model=None, options=None, face_boost=None):

        if options is not None:
            self.faces_order = [options["input_faces_order"], options["source_faces_order"]]
            self.console_log_level = options["console_log_level"]
            self.detect_gender_input = options["detect_gender_input"]
            self.detect_gender_source = options["detect_gender_source"]
            self.input_faces_index = options["input_faces_index"]
            self.source_faces_index = options["source_faces_index"]

        if face_boost is not None:
            self.face_boost_enabled = face_boost["enabled"]
            self.restore = face_boost["restore_with_main_after"]
        else:
            self.face_boost_enabled = False

        result = reactor.execute(
            self,enabled,input_image,swap_model,self.detect_gender_source,self.detect_gender_input,self.source_faces_index,self.input_faces_index,self.console_log_level,face_restore_model,face_restore_visibility,codeformer_weight,facedetection,source_image,face_model,self.faces_order, face_boost=face_boost
        )

        return result


class LoadFaceModel:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "face_model": (get_model_names(get_facemodels),),
            }
        }

    RETURN_TYPES = ("FACE_MODEL",)
    FUNCTION = "load_model"
    CATEGORY = "🌌 ReActor"

    def load_model(self, face_model):
        self.face_model = face_model
        self.face_models_path = FACE_MODELS_PATH
        if self.face_model != "none":
            face_model_path = os.path.join(self.face_models_path, self.face_model)
            out = load_face_model(face_model_path)
        else:
            out = None
        return (out, )


class ReActorWeight:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "input_image": ("IMAGE",),
                "faceswap_weight": (["0%", "12.5%", "25%", "37.5%", "50%", "62.5%", "75%", "87.5%", "100%"], {"default": "50%"}),
            },
            "optional": {
                "source_image": ("IMAGE",),
                "face_model": ("FACE_MODEL",),
            }
        }

    RETURN_TYPES = ("IMAGE","FACE_MODEL")
    RETURN_NAMES = ("INPUT_IMAGE","FACE_MODEL")
    FUNCTION = "set_weight"

    OUTPUT_NODE = True

    CATEGORY = "🌌 ReActor"

    def set_weight(self, input_image, faceswap_weight, face_model=None, source_image=None):

        if input_image is None:
            logger.error("Please provide `input_image`")
            return (input_image,None)

        if source_image is None and face_model is None:
            logger.error("Please provide `source_image` or `face_model`")
            return (input_image,None)

        weight = float(faceswap_weight.split("%")[0])

        images = []
        faces = [] if face_model is None else [face_model]
        embeddings = [] if face_model is None else [face_model.embedding]

        if weight == 0:
            images = [input_image]
            faces = []
            embeddings = []
        elif weight == 100:
            if face_model is None:
                images = [source_image]
        else:
            if weight > 50:
                images = [input_image]
                count = round(100/(100-weight))
            else:
                if face_model is None:
                    images = [source_image]
                count = round(100/(weight))
            for i in range(count-1):
                if weight > 50:
                    if face_model is None:
                        images.append(source_image)
                    else:
                        faces.append(face_model)
                        embeddings.append(face_model.embedding)
                else:
                    images.append(input_image)

        images_list: List[Image.Image] = []

        apply_patch(1)

        if len(images) > 0:

            for image in images:
                img = tensor_to_pil(image)
                images_list.append(img)

            for image in images_list:
                face = BuildFaceModel.build_face_model(self,image)
                if isinstance(face, str):
                    continue
                faces.append(face)
                embeddings.append(face.embedding)

        if len(faces) > 0:
            blended_embedding = np.mean(embeddings, axis=0)
            blended_face = Face(
                bbox=faces[0].bbox,
                kps=faces[0].kps,
                det_score=faces[0].det_score,
                landmark_3d_68=faces[0].landmark_3d_68,
                pose=faces[0].pose,
                landmark_2d_106=faces[0].landmark_2d_106,
                embedding=blended_embedding,
                gender=faces[0].gender,
                age=faces[0].age
            )
            if blended_face is None:
                no_face_msg = "Something went wrong, please try another set of images"
                logger.error(no_face_msg)

        return (input_image,blended_face)


class BuildFaceModel:
    def __init__(self):
        self.output_dir = FACE_MODELS_PATH

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "save_mode": ("BOOLEAN", {"default": True, "label_off": "OFF", "label_on": "ON"}),
                "send_only": ("BOOLEAN", {"default": False, "label_off": "NO", "label_on": "YES"}),
                "face_model_name": ("STRING", {"default": "default"}),
                "compute_method": (["Mean", "Median", "Mode"], {"default": "Mean"}),
            },
            "optional": {
                "images": ("IMAGE",),
                "face_models": ("FACE_MODEL",),
            }
        }

    RETURN_TYPES = ("FACE_MODEL",)
    FUNCTION = "blend_faces"

    OUTPUT_NODE = True

    CATEGORY = "🌌 ReActor"

    def build_face_model(self, image: Image.Image, det_size=(640, 640)):
        logging.StreamHandler.terminator = "\n"
        if image is None:
            error_msg = "Please load an Image"
            logger.error(error_msg)
            return error_msg
        image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        face_model = analyze_faces(image, det_size)

        if len(face_model) == 0:
            print("")
            det_size_half = half_det_size(det_size)
            face_model = analyze_faces(image, det_size_half)
            if face_model is not None and len(face_model) > 0:
                print("...........................................................", end=" ")

        if face_model is not None and len(face_model) > 0:
            return face_model[0]
        else:
            no_face_msg = "No face found, please try another image"
            # logger.error(no_face_msg)
            return no_face_msg

    def blend_faces(self, save_mode, send_only, face_model_name, compute_method, images=None, face_models=None):
        global BLENDED_FACE_MODEL
        blended_face: Face = BLENDED_FACE_MODEL

        if send_only and blended_face is None:
            send_only = False

        if (images is not None or face_models is not None) and not send_only:

            faces = []
            embeddings = []

            apply_patch(1)

            if images is not None:
                images_list: List[Image.Image] = batch_tensor_to_pil(images)

                n = len(images_list)

                for i,image in enumerate(images_list):
                    logging.StreamHandler.terminator = " "
                    logger.status(f"Building Face Model {i+1} of {n}...")
                    face = self.build_face_model(image)
                    if isinstance(face, str):
                        logger.error(f"No faces found in image {i+1}, skipping")
                        continue
                    else:
                        print(f"{int(((i+1)/n)*100)}%")
                    faces.append(face)
                    embeddings.append(face.embedding)

            elif face_models is not None:

                n = len(face_models)

                for i,face_model in enumerate(face_models):
                    logging.StreamHandler.terminator = " "
                    logger.status(f"Extracting Face Model {i+1} of {n}...")
                    face = face_model
                    if isinstance(face, str):
                        logger.error(f"No faces found for face_model {i+1}, skipping")
                        continue
                    else:
                        print(f"{int(((i+1)/n)*100)}%")
                    faces.append(face)
                    embeddings.append(face.embedding)

            logging.StreamHandler.terminator = "\n"
            if len(faces) > 0:
                # compute_method_name = "Mean" if compute_method == 0 else "Median" if compute_method == 1 else "Mode"
                logger.status(f"Blending with Compute Method '{compute_method}'...")
                blended_embedding = np.mean(embeddings, axis=0) if compute_method == "Mean" else np.median(embeddings, axis=0) if compute_method == "Median" else stats.mode(embeddings, axis=0)[0].astype(np.float32)
                blended_face = Face(
                    bbox=faces[0].bbox,
                    kps=faces[0].kps,
                    det_score=faces[0].det_score,
                    landmark_3d_68=faces[0].landmark_3d_68,
                    pose=faces[0].pose,
                    landmark_2d_106=faces[0].landmark_2d_106,
                    embedding=blended_embedding,
                    gender=faces[0].gender,
                    age=faces[0].age
                )
                if blended_face is not None:
                    BLENDED_FACE_MODEL = blended_face
                    if save_mode:
                        face_model_path = os.path.join(FACE_MODELS_PATH, face_model_name + ".safetensors")
                        save_face_model(blended_face,face_model_path)
                        # done_msg = f"Face model has been saved to '{face_model_path}'"
                        # logger.status(done_msg)
                    logger.status("--Done!--")
                    # return (blended_face,)
                else:
                    no_face_msg = "Something went wrong, please try another set of images"
                    logger.error(no_face_msg)
                    # return (blended_face,)
        # logger.status("--Done!--")
        if images is None and face_models is None:
            logger.error("Please provide `images` or `face_models`")
        return (blended_face,)


class SaveFaceModel:
    def __init__(self):
        self.output_dir = FACE_MODELS_PATH

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "save_mode": ("BOOLEAN", {"default": True, "label_off": "OFF", "label_on": "ON"}),
                "face_model_name": ("STRING", {"default": "default"}),
                "select_face_index": ("INT", {"default": 0, "min": 0}),
            },
            "optional": {
                "image": ("IMAGE",),
                "face_model": ("FACE_MODEL",),
            }
        }

    RETURN_TYPES = ()
    FUNCTION = "save_model"

    OUTPUT_NODE = True

    CATEGORY = "🌌 ReActor"

    def save_model(self, save_mode, face_model_name, select_face_index, image=None, face_model=None, det_size=(640, 640)):
        if save_mode and image is not None:
            source = tensor_to_pil(image)
            source = cv2.cvtColor(np.array(source), cv2.COLOR_RGB2BGR)
            apply_patch(1)
            logger.status("Building Face Model...")
            face_model_raw = analyze_faces(source, det_size)
            if len(face_model_raw) == 0:
                det_size_half = half_det_size(det_size)
                face_model_raw = analyze_faces(source, det_size_half)
            try:
                face_model = face_model_raw[select_face_index]
            except:
                logger.error("No face(s) found")
                return face_model_name
            logger.status("--Done!--")
        if save_mode and (face_model != "none" or face_model is not None):
            face_model_path = os.path.join(self.output_dir, face_model_name + ".safetensors")
            save_face_model(face_model,face_model_path)
        if image is None and face_model is None:
            logger.error("Please provide `face_model` or `image`")
        return face_model_name


class RestoreFace:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "facedetection": (["retinaface_resnet50", "retinaface_mobile0.25", "YOLOv5l", "YOLOv5n"],),
                "model": (get_model_names(get_restorers),),
                "visibility": ("FLOAT", {"default": 1, "min": 0.0, "max": 1, "step": 0.05}),
                "codeformer_weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1, "step": 0.05}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"
    CATEGORY = "🌌 ReActor"

    # def __init__(self):
    #     self.face_helper = None
    #     self.face_size = 512

    def execute(self, image, model, visibility, codeformer_weight, facedetection):
        result = reactor.restore_face(self,image,model,visibility,codeformer_weight,facedetection)
        return (result,)


class MaskHelper:
    def __init__(self):
        # self.threshold = 0.5
        # self.dilation = 10
        # self.crop_factor = 3.0
        # self.drop_size = 1
        self.labels = "all"
        self.detailer_hook = None
        self.device_mode = "AUTO"
        self.detection_hint = "center-1"
        # self.sam_dilation = 0
        # self.sam_threshold = 0.93
        # self.bbox_expansion = 0
        # self.mask_hint_threshold = 0.7
        # self.mask_hint_use_negative = "False"
        # self.force_resize_width = 0
        # self.force_resize_height = 0
        # self.resize_behavior = "source_size"

    @classmethod
    def INPUT_TYPES(s):
        bboxs = ["bbox/"+x for x in folder_paths.get_filename_list("ultralytics_bbox")]
        segms = ["segm/"+x for x in folder_paths.get_filename_list("ultralytics_segm")]
        sam_models = [x for x in folder_paths.get_filename_list("sams") if 'hq' not in x]
        return {
            "required": {
                "image": ("IMAGE",),
                "swapped_image": ("IMAGE",),
                "bbox_model_name": (bboxs + segms, ),
                "bbox_threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "bbox_dilation": ("INT", {"default": 10, "min": -512, "max": 512, "step": 1}),
                "bbox_crop_factor": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 100, "step": 0.1}),
                "bbox_drop_size": ("INT", {"min": 1, "max": 8192, "step": 1, "default": 10}),
                "sam_model_name": (sam_models, ),
                "sam_dilation": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
                "sam_threshold": ("FLOAT", {"default": 0.93, "min": 0.0, "max": 1.0, "step": 0.01}),
                "bbox_expansion": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
                "mask_hint_threshold": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
                "mask_hint_use_negative": (["False", "Small", "Outter"], ),
                "morphology_operation": (["dilate", "erode", "open", "close"],),
                "morphology_distance": ("INT", {"default": 0, "min": 0, "max": 128, "step": 1}),
                "blur_radius": ("INT", {"default": 9, "min": 0, "max": 48, "step": 1}),
                "sigma_factor": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 3., "step": 0.01}),
            },
            "optional": {
                "mask_optional": ("MASK",),
            }
        }

    RETURN_TYPES = ("IMAGE","MASK","IMAGE","IMAGE")
    RETURN_NAMES = ("IMAGE","MASK","MASK_PREVIEW","SWAPPED_FACE")
    FUNCTION = "execute"
    CATEGORY = "🌌 ReActor"

    def execute(self, image, swapped_image, bbox_model_name, bbox_threshold, bbox_dilation, bbox_crop_factor, bbox_drop_size, sam_model_name, sam_dilation, sam_threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative, morphology_operation, morphology_distance, blur_radius, sigma_factor, mask_optional=None):

        # images = [image[i:i + 1, ...] for i in range(image.shape[0])]

        images = image

        if mask_optional is None:

            bbox_model_path = folder_paths.get_full_path("ultralytics", bbox_model_name)
            bbox_model = subcore.load_yolo(bbox_model_path)
            bbox_detector = subcore.UltraBBoxDetector(bbox_model)

            segs = bbox_detector.detect(images, bbox_threshold, bbox_dilation, bbox_crop_factor, bbox_drop_size, self.detailer_hook)

            if isinstance(self.labels, list):
                self.labels = str(self.labels[0])

            if self.labels is not None and self.labels != '':
                self.labels = self.labels.split(',')
                if len(self.labels) > 0:
                    segs, _ = masking_segs.filter(segs, self.labels)
                    # segs, _ = masking_segs.filter(segs, "all")

            sam_modelname = folder_paths.get_full_path("sams", sam_model_name)

            if 'vit_h' in sam_model_name:
                model_kind = 'vit_h'
            elif 'vit_l' in sam_model_name:
                model_kind = 'vit_l'
            else:
                model_kind = 'vit_b'

            sam = sam_model_registry[model_kind](checkpoint=sam_modelname)
            size = os.path.getsize(sam_modelname)
            sam.safe_to = core.SafeToGPU(size)

            device = model_management.get_torch_device()

            sam.safe_to.to_device(sam, device)

            sam.is_auto_mode = self.device_mode == "AUTO"

            combined_mask, _ = core.make_sam_mask_segmented(sam, segs, images, self.detection_hint, sam_dilation, sam_threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative)

        else:
            combined_mask = mask_optional

        # *** MASK TO IMAGE ***:

        mask_image = combined_mask.reshape((-1, 1, combined_mask.shape[-2], combined_mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)

        # *** MASK MORPH ***:

        mask_image = core.tensor2mask(mask_image)

        if morphology_operation == "dilate":
            mask_image = self.dilate(mask_image, morphology_distance)
        elif morphology_operation == "erode":
            mask_image = self.erode(mask_image, morphology_distance)
        elif morphology_operation == "open":
            mask_image = self.erode(mask_image, morphology_distance)
            mask_image = self.dilate(mask_image, morphology_distance)
        elif morphology_operation == "close":
            mask_image = self.dilate(mask_image, morphology_distance)
            mask_image = self.erode(mask_image, morphology_distance)

        # *** MASK BLUR ***:

        if len(mask_image.size()) == 3:
            mask_image = mask_image.unsqueeze(3)

        mask_image = mask_image.permute(0, 3, 1, 2)
        kernel_size = blur_radius * 2 + 1
        sigma = sigma_factor * (0.6 * blur_radius - 0.3)
        mask_image_final = self.gaussian_blur(mask_image, kernel_size, sigma).permute(0, 2, 3, 1)
        if mask_image_final.size()[3] == 1:
            mask_image_final = mask_image_final[:, :, :, 0]

        # *** CUT BY MASK ***:

        if len(swapped_image.shape) < 4:
            C = 1
        else:
            C = swapped_image.shape[3]

        # We operate on RGBA to keep the code clean and then convert back after
        swapped_image = core.tensor2rgba(swapped_image)
        mask = core.tensor2mask(mask_image_final)

        # Scale the mask to be a matching size if it isn't
        B, H, W, _ = swapped_image.shape
        mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(H, W), mode='nearest')[:,0,:,:]
        MB, _, _ = mask.shape

        if MB < B:
            assert(B % MB == 0)
            mask = mask.repeat(B // MB, 1, 1)

        # masks_to_boxes errors if the tensor is all zeros, so we'll add a single pixel and zero it out at the end
        is_empty = ~torch.gt(torch.max(torch.reshape(mask,[MB, H * W]), dim=1).values, 0.)
        mask[is_empty,0,0] = 1.
        boxes = masks_to_boxes(mask)
        mask[is_empty,0,0] = 0.

        min_x = boxes[:,0]
        min_y = boxes[:,1]
        max_x = boxes[:,2]
        max_y = boxes[:,3]

        width = max_x - min_x + 1
        height = max_y - min_y + 1

        use_width = int(torch.max(width).item())
        use_height = int(torch.max(height).item())

        # if self.force_resize_width > 0:
        #     use_width = self.force_resize_width

        # if self.force_resize_height > 0:
        #     use_height = self.force_resize_height

        alpha_mask = torch.ones((B, H, W, 4))
        alpha_mask[:,:,:,3] = mask

        swapped_image = swapped_image * alpha_mask

        cutted_image = torch.zeros((B, use_height, use_width, 4))
        for i in range(0, B):
            if not is_empty[i]:
                ymin = int(min_y[i].item())
                ymax = int(max_y[i].item())
                xmin = int(min_x[i].item())
                xmax = int(max_x[i].item())
                single = (swapped_image[i, ymin:ymax+1, xmin:xmax+1,:]).unsqueeze(0)
                resized = torch.nn.functional.interpolate(single.permute(0, 3, 1, 2), size=(use_height, use_width), mode='bicubic').permute(0, 2, 3, 1)
                cutted_image[i] = resized[0]

        # Preserve our type unless we were previously RGB and added non-opaque alpha due to the mask size
        if C == 1:
            cutted_image = core.tensor2mask(cutted_image)
        elif C == 3 and torch.min(cutted_image[:,:,:,3]) == 1:
            cutted_image = core.tensor2rgb(cutted_image)

        # *** PASTE BY MASK ***:

        image_base = core.tensor2rgba(images)
        image_to_paste = core.tensor2rgba(cutted_image)
        mask = core.tensor2mask(mask_image_final)

        # Scale the mask to be a matching size if it isn't
        B, H, W, C = image_base.shape
        MB = mask.shape[0]
        PB = image_to_paste.shape[0]

        if B < PB:
            assert(PB % B == 0)
            image_base = image_base.repeat(PB // B, 1, 1, 1)
        B, H, W, C = image_base.shape
        if MB < B:
            assert(B % MB == 0)
            mask = mask.repeat(B // MB, 1, 1)
        elif B < MB:
            assert(MB % B == 0)
            image_base = image_base.repeat(MB // B, 1, 1, 1)
        if PB < B:
            assert(B % PB == 0)
|
841 |
+
"sam_threshold": ("FLOAT", {"default": 0.93, "min": 0.0, "max": 1.0, "step": 0.01}),
|
842 |
+
"bbox_expansion": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
|
843 |
+
"mask_hint_threshold": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
|
844 |
+
"mask_hint_use_negative": (["False", "Small", "Outter"], ),
|
845 |
+
"morphology_operation": (["dilate", "erode", "open", "close"],),
|
846 |
+
"morphology_distance": ("INT", {"default": 0, "min": 0, "max": 128, "step": 1}),
|
847 |
+
"blur_radius": ("INT", {"default": 9, "min": 0, "max": 48, "step": 1}),
|
848 |
+
"sigma_factor": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 3., "step": 0.01}),
|
849 |
+
},
|
850 |
+
"optional": {
|
851 |
+
"mask_optional": ("MASK",),
|
852 |
+
}
|
853 |
+
}
|
854 |
+
|
855 |
+
RETURN_TYPES = ("IMAGE","MASK","IMAGE","IMAGE")
|
856 |
+
RETURN_NAMES = ("IMAGE","MASK","MASK_PREVIEW","SWAPPED_FACE")
|
857 |
+
FUNCTION = "execute"
|
858 |
+
CATEGORY = "🌌 ReActor"
|
859 |
+
|
860 |
+
def execute(self, image, swapped_image, bbox_model_name, bbox_threshold, bbox_dilation, bbox_crop_factor, bbox_drop_size, sam_model_name, sam_dilation, sam_threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative, morphology_operation, morphology_distance, blur_radius, sigma_factor, mask_optional=None):
|
861 |
+
|
862 |
+
# images = [image[i:i + 1, ...] for i in range(image.shape[0])]
|
863 |
+
|
864 |
+
images = image
|
865 |
+
|
866 |
+
if mask_optional is None:
|
867 |
+
|
868 |
+
bbox_model_path = folder_paths.get_full_path("ultralytics", bbox_model_name)
|
869 |
+
bbox_model = subcore.load_yolo(bbox_model_path)
|
870 |
+
bbox_detector = subcore.UltraBBoxDetector(bbox_model)
|
871 |
+
|
872 |
+
segs = bbox_detector.detect(images, bbox_threshold, bbox_dilation, bbox_crop_factor, bbox_drop_size, self.detailer_hook)
|
873 |
+
|
874 |
+
if isinstance(self.labels, list):
|
875 |
+
self.labels = str(self.labels[0])
|
876 |
+
|
877 |
+
if self.labels is not None and self.labels != '':
|
878 |
+
self.labels = self.labels.split(',')
|
879 |
+
if len(self.labels) > 0:
|
880 |
+
segs, _ = masking_segs.filter(segs, self.labels)
|
881 |
+
# segs, _ = masking_segs.filter(segs, "all")
|
882 |
+
|
883 |
+
sam_modelname = folder_paths.get_full_path("sams", sam_model_name)
|
884 |
+
|
885 |
+
if 'vit_h' in sam_model_name:
|
886 |
+
model_kind = 'vit_h'
|
887 |
+
elif 'vit_l' in sam_model_name:
|
888 |
+
model_kind = 'vit_l'
|
889 |
+
else:
|
890 |
+
model_kind = 'vit_b'
|
891 |
+
|
892 |
+
sam = sam_model_registry[model_kind](checkpoint=sam_modelname)
|
893 |
+
size = os.path.getsize(sam_modelname)
|
894 |
+
sam.safe_to = core.SafeToGPU(size)
|
895 |
+
|
896 |
+
device = model_management.get_torch_device()
|
897 |
+
|
898 |
+
sam.safe_to.to_device(sam, device)
|
899 |
+
|
900 |
+
sam.is_auto_mode = self.device_mode == "AUTO"
|
901 |
+
|
902 |
+
combined_mask, _ = core.make_sam_mask_segmented(sam, segs, images, self.detection_hint, sam_dilation, sam_threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative)
|
903 |
+
|
904 |
+
else:
|
905 |
+
combined_mask = mask_optional
|
906 |
+
|
907 |
+
# *** MASK TO IMAGE ***:
|
908 |
+
|
909 |
+
mask_image = combined_mask.reshape((-1, 1, combined_mask.shape[-2], combined_mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
910 |
+
|
911 |
+
# *** MASK MORPH ***:
|
912 |
+
|
913 |
+
mask_image = core.tensor2mask(mask_image)
|
914 |
+
|
915 |
+
if morphology_operation == "dilate":
|
916 |
+
mask_image = self.dilate(mask_image, morphology_distance)
|
917 |
+
elif morphology_operation == "erode":
|
918 |
+
mask_image = self.erode(mask_image, morphology_distance)
|
919 |
+
elif morphology_operation == "open":
|
920 |
+
mask_image = self.erode(mask_image, morphology_distance)
|
921 |
+
mask_image = self.dilate(mask_image, morphology_distance)
|
922 |
+
elif morphology_operation == "close":
|
923 |
+
mask_image = self.dilate(mask_image, morphology_distance)
|
924 |
+
mask_image = self.erode(mask_image, morphology_distance)
|
925 |
+
|
926 |
+
# *** MASK BLUR ***:
|
927 |
+
|
928 |
+
if len(mask_image.size()) == 3:
|
929 |
+
mask_image = mask_image.unsqueeze(3)
|
930 |
+
|
931 |
+
mask_image = mask_image.permute(0, 3, 1, 2)
|
932 |
+
kernel_size = blur_radius * 2 + 1
|
933 |
+
sigma = sigma_factor * (0.6 * blur_radius - 0.3)
|
934 |
+
mask_image_final = self.gaussian_blur(mask_image, kernel_size, sigma).permute(0, 2, 3, 1)
|
935 |
+
if mask_image_final.size()[3] == 1:
|
936 |
+
mask_image_final = mask_image_final[:, :, :, 0]
|
937 |
+
|
938 |
+
# *** CUT BY MASK ***:
|
939 |
+
|
940 |
+
if len(swapped_image.shape) < 4:
|
941 |
+
C = 1
|
942 |
+
else:
|
943 |
+
C = swapped_image.shape[3]
|
944 |
+
|
945 |
+
# We operate on RGBA to keep the code clean and then convert back after
|
946 |
+
swapped_image = core.tensor2rgba(swapped_image)
|
947 |
+
mask = core.tensor2mask(mask_image_final)
|
948 |
+
|
949 |
+
# Scale the mask to be a matching size if it isn't
|
950 |
+
B, H, W, _ = swapped_image.shape
|
951 |
+
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(H, W), mode='nearest')[:,0,:,:]
|
952 |
+
MB, _, _ = mask.shape
|
953 |
+
|
954 |
+
if MB < B:
|
955 |
+
assert(B % MB == 0)
|
956 |
+
mask = mask.repeat(B // MB, 1, 1)
|
957 |
+
|
958 |
+
# masks_to_boxes errors if the tensor is all zeros, so we'll add a single pixel and zero it out at the end
|
959 |
+
is_empty = ~torch.gt(torch.max(torch.reshape(mask,[MB, H * W]), dim=1).values, 0.)
|
960 |
+
mask[is_empty,0,0] = 1.
|
961 |
+
boxes = masks_to_boxes(mask)
|
962 |
+
mask[is_empty,0,0] = 0.
|
963 |
+
|
964 |
+
min_x = boxes[:,0]
|
965 |
+
min_y = boxes[:,1]
|
966 |
+
max_x = boxes[:,2]
|
967 |
+
max_y = boxes[:,3]
|
968 |
+
|
969 |
+
width = max_x - min_x + 1
|
970 |
+
height = max_y - min_y + 1
|
971 |
+
|
972 |
+
use_width = int(torch.max(width).item())
|
973 |
+
use_height = int(torch.max(height).item())
|
974 |
+
|
975 |
+
# if self.force_resize_width > 0:
|
976 |
+
# use_width = self.force_resize_width
|
977 |
+
|
978 |
+
# if self.force_resize_height > 0:
|
979 |
+
# use_height = self.force_resize_height
|
980 |
+
|
981 |
+
alpha_mask = torch.ones((B, H, W, 4))
|
982 |
+
alpha_mask[:,:,:,3] = mask
|
983 |
+
|
984 |
+
swapped_image = swapped_image * alpha_mask
|
985 |
+
|
986 |
+
cutted_image = torch.zeros((B, use_height, use_width, 4))
|
987 |
+
for i in range(0, B):
|
988 |
+
if not is_empty[i]:
|
989 |
+
ymin = int(min_y[i].item())
|
990 |
+
ymax = int(max_y[i].item())
|
991 |
+
xmin = int(min_x[i].item())
|
992 |
+
xmax = int(max_x[i].item())
|
993 |
+
single = (swapped_image[i, ymin:ymax+1, xmin:xmax+1,:]).unsqueeze(0)
|
994 |
+
resized = torch.nn.functional.interpolate(single.permute(0, 3, 1, 2), size=(use_height, use_width), mode='bicubic').permute(0, 2, 3, 1)
|
995 |
+
cutted_image[i] = resized[0]
|
996 |
+
|
997 |
+
# Preserve our type unless we were previously RGB and added non-opaque alpha due to the mask size
|
998 |
+
if C == 1:
|
999 |
+
cutted_image = core.tensor2mask(cutted_image)
|
1000 |
+
elif C == 3 and torch.min(cutted_image[:,:,:,3]) == 1:
|
1001 |
+
cutted_image = core.tensor2rgb(cutted_image)
|
1002 |
+
|
1003 |
+
# *** PASTE BY MASK ***:
|
1004 |
+
|
1005 |
+
image_base = core.tensor2rgba(images)
|
1006 |
+
image_to_paste = core.tensor2rgba(cutted_image)
|
1007 |
+
mask = core.tensor2mask(mask_image_final)
|
1008 |
+
|
1009 |
+
# Scale the mask to be a matching size if it isn't
|
1010 |
+
B, H, W, C = image_base.shape
|
1011 |
+
MB = mask.shape[0]
|
1012 |
+
PB = image_to_paste.shape[0]
|
1013 |
+
|
1014 |
+
if B < PB:
|
1015 |
+
assert(PB % B == 0)
|
1016 |
+
image_base = image_base.repeat(PB // B, 1, 1, 1)
|
1017 |
+
B, H, W, C = image_base.shape
|
1018 |
+
if MB < B:
|
1019 |
+
assert(B % MB == 0)
|
1020 |
+
mask = mask.repeat(B // MB, 1, 1)
|
1021 |
+
elif B < MB:
|
1022 |
+
assert(MB % B == 0)
|
1023 |
+
image_base = image_base.repeat(MB // B, 1, 1, 1)
|
1024 |
+
if PB < B:
|
1025 |
+
assert(B % PB == 0)
|
1026 |
+
image_to_paste = image_to_paste.repeat(B // PB, 1, 1, 1)
|
1027 |
+
|
1028 |
+
mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(H, W), mode='nearest')[:,0,:,:]
|
1029 |
+
MB, MH, MW = mask.shape
|
1030 |
+
|
1031 |
+
# masks_to_boxes errors if the tensor is all zeros, so we'll add a single pixel and zero it out at the end
|
1032 |
+
is_empty = ~torch.gt(torch.max(torch.reshape(mask,[MB, MH * MW]), dim=1).values, 0.)
|
1033 |
+
mask[is_empty,0,0] = 1.
|
1034 |
+
boxes = masks_to_boxes(mask)
|
1035 |
+
mask[is_empty,0,0] = 0.
|
1036 |
+
|
1037 |
+
min_x = boxes[:,0]
|
1038 |
+
min_y = boxes[:,1]
|
1039 |
+
max_x = boxes[:,2]
|
1040 |
+
max_y = boxes[:,3]
|
1041 |
+
mid_x = (min_x + max_x) / 2
|
1042 |
+
mid_y = (min_y + max_y) / 2
|
1043 |
+
|
1044 |
+
target_width = max_x - min_x + 1
|
1045 |
+
target_height = max_y - min_y + 1
|
1046 |
+
|
1047 |
+
result = image_base.detach().clone()
|
1048 |
+
face_segment = mask_image_final
|
1049 |
+
|
1050 |
+
for i in range(0, MB):
|
1051 |
+
if is_empty[i]:
|
1052 |
+
continue
|
1053 |
+
else:
|
1054 |
+
image_index = i
|
1055 |
+
source_size = image_to_paste.size()
|
1056 |
+
SB, SH, SW, _ = image_to_paste.shape
|
1057 |
+
|
1058 |
+
# Figure out the desired size
|
1059 |
+
width = int(target_width[i].item())
|
1060 |
+
height = int(target_height[i].item())
|
1061 |
+
# if self.resize_behavior == "keep_ratio_fill":
|
1062 |
+
# target_ratio = width / height
|
1063 |
+
# actual_ratio = SW / SH
|
1064 |
+
# if actual_ratio > target_ratio:
|
1065 |
+
# width = int(height * actual_ratio)
|
1066 |
+
# elif actual_ratio < target_ratio:
|
1067 |
+
# height = int(width / actual_ratio)
|
1068 |
+
# elif self.resize_behavior == "keep_ratio_fit":
|
1069 |
+
# target_ratio = width / height
|
1070 |
+
# actual_ratio = SW / SH
|
1071 |
+
# if actual_ratio > target_ratio:
|
1072 |
+
# height = int(width / actual_ratio)
|
1073 |
+
# elif actual_ratio < target_ratio:
|
1074 |
+
# width = int(height * actual_ratio)
|
1075 |
+
# elif self.resize_behavior == "source_size" or self.resize_behavior == "source_size_unmasked":
|
1076 |
+
|
1077 |
+
width = SW
|
1078 |
+
height = SH
|
1079 |
+
|
1080 |
+
# Resize the image we're pasting if needed
|
1081 |
+
resized_image = image_to_paste[i].unsqueeze(0)
|
1082 |
+
# if SH != height or SW != width:
|
1083 |
+
# resized_image = torch.nn.functional.interpolate(resized_image.permute(0, 3, 1, 2), size=(height,width), mode='bicubic').permute(0, 2, 3, 1)
|
1084 |
+
|
1085 |
+
pasting = torch.ones([H, W, C])
|
1086 |
+
ymid = float(mid_y[i].item())
|
1087 |
+
ymin = int(math.floor(ymid - height / 2)) + 1
|
1088 |
+
ymax = int(math.floor(ymid + height / 2)) + 1
|
1089 |
+
xmid = float(mid_x[i].item())
|
1090 |
+
xmin = int(math.floor(xmid - width / 2)) + 1
|
1091 |
+
xmax = int(math.floor(xmid + width / 2)) + 1
|
1092 |
+
|
1093 |
+
_, source_ymax, source_xmax, _ = resized_image.shape
|
1094 |
+
source_ymin, source_xmin = 0, 0
|
1095 |
+
|
1096 |
+
if xmin < 0:
|
1097 |
+
source_xmin = abs(xmin)
|
1098 |
+
xmin = 0
|
1099 |
+
if ymin < 0:
|
1100 |
+
source_ymin = abs(ymin)
|
1101 |
+
ymin = 0
|
1102 |
+
if xmax > W:
|
1103 |
+
source_xmax -= (xmax - W)
|
1104 |
+
xmax = W
|
1105 |
+
if ymax > H:
|
1106 |
+
source_ymax -= (ymax - H)
|
1107 |
+
ymax = H
|
1108 |
+
|
1109 |
+
pasting[ymin:ymax, xmin:xmax, :] = resized_image[0, source_ymin:source_ymax, source_xmin:source_xmax, :]
|
1110 |
+
pasting[:, :, 3] = 1.
|
1111 |
+
|
1112 |
+
pasting_alpha = torch.zeros([H, W])
|
1113 |
+
pasting_alpha[ymin:ymax, xmin:xmax] = resized_image[0, source_ymin:source_ymax, source_xmin:source_xmax, 3]
|
1114 |
+
|
1115 |
+
# if self.resize_behavior == "keep_ratio_fill" or self.resize_behavior == "source_size_unmasked":
|
1116 |
+
# # If we explicitly want to fill the area, we are ok with extending outside
|
1117 |
+
# paste_mask = pasting_alpha.unsqueeze(2).repeat(1, 1, 4)
|
1118 |
+
# else:
|
1119 |
+
# paste_mask = torch.min(pasting_alpha, mask[i]).unsqueeze(2).repeat(1, 1, 4)
|
1120 |
+
paste_mask = torch.min(pasting_alpha, mask[i]).unsqueeze(2).repeat(1, 1, 4)
|
1121 |
+
result[image_index] = pasting * paste_mask + result[image_index] * (1. - paste_mask)
|
1122 |
+
|
1123 |
+
face_segment = result
|
1124 |
+
|
1125 |
+
face_segment[...,3] = mask[i]
|
1126 |
+
|
1127 |
+
result = rgba2rgb_tensor(result)
|
1128 |
+
|
1129 |
+
return (result,combined_mask,mask_image_final,face_segment,)
|
1130 |
+
|
1131 |
+
def gaussian_blur(self, image, kernel_size, sigma):
|
1132 |
+
kernel = torch.Tensor(kernel_size, kernel_size).to(device=image.device)
|
1133 |
+
center = kernel_size // 2
|
1134 |
+
variance = sigma**2
|
1135 |
+
for i in range(kernel_size):
|
1136 |
+
for j in range(kernel_size):
|
1137 |
+
x = i - center
|
1138 |
+
y = j - center
|
1139 |
+
kernel[i, j] = math.exp(-(x**2 + y**2)/(2*variance))
|
1140 |
+
kernel /= kernel.sum()
|
1141 |
+
|
1142 |
+
# Pad the input tensor
|
1143 |
+
padding = (kernel_size - 1) // 2
|
1144 |
+
input_pad = torch.nn.functional.pad(image, (padding, padding, padding, padding), mode='reflect')
|
1145 |
+
|
1146 |
+
# Reshape the padded input tensor for batched convolution
|
1147 |
+
batch_size, num_channels, height, width = image.shape
|
1148 |
+
input_reshaped = input_pad.reshape(batch_size*num_channels, 1, height+padding*2, width+padding*2)
|
1149 |
+
|
1150 |
+
# Perform batched convolution with the Gaussian kernel
|
1151 |
+
output_reshaped = torch.nn.functional.conv2d(input_reshaped, kernel.unsqueeze(0).unsqueeze(0))
|
1152 |
+
|
1153 |
+
# Reshape the output tensor to its original shape
|
1154 |
+
output_tensor = output_reshaped.reshape(batch_size, num_channels, height, width)
|
1155 |
+
|
1156 |
+
return output_tensor
|
1157 |
+
|
1158 |
+
def erode(self, image, distance):
|
1159 |
+
return 1. - self.dilate(1. - image, distance)
|
1160 |
+
|
1161 |
+
def dilate(self, image, distance):
|
1162 |
+
kernel_size = 1 + distance * 2
|
1163 |
+
# Add the channels dimension
|
1164 |
+
image = image.unsqueeze(1)
|
1165 |
+
out = torchfn.max_pool2d(image, kernel_size=kernel_size, stride=1, padding=kernel_size // 2).squeeze(1)
|
1166 |
+
return out
|
1167 |
+
|
1168 |
+
|
1169 |
+
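# Illustrative note on the morphology helpers above (a sketch for clarity, not
# part of the original file): `dilate` grows a soft mask by max-pooling over a
# (2*distance + 1) square window, and `erode` is its dual via
# 1 - dilate(1 - mask). With distance=1, a single lit pixel becomes a 3x3 block:
#
#   m = torch.zeros(1, 5, 5); m[0, 2, 2] = 1.
#   assert MaskHelper.dilate(None, m, 1)[0, 1:4, 1:4].min() == 1.
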
class ImageDublicator:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "count": ("INT", {"default": 1, "min": 0}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("IMAGES",)
    OUTPUT_IS_LIST = (True,)
    FUNCTION = "execute"
    CATEGORY = "🌌 ReActor"

    def execute(self, image, count):
        images = [image for i in range(count)]
        return (images,)


class ImageRGBA2RGB:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"
    CATEGORY = "🌌 ReActor"

    def execute(self, image):
        out = rgba2rgb_tensor(image)
        return (out,)


class MakeFaceModelBatch:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "face_model1": ("FACE_MODEL",),
            },
            "optional": {
                "face_model2": ("FACE_MODEL",),
                "face_model3": ("FACE_MODEL",),
                "face_model4": ("FACE_MODEL",),
                "face_model5": ("FACE_MODEL",),
                "face_model6": ("FACE_MODEL",),
                "face_model7": ("FACE_MODEL",),
                "face_model8": ("FACE_MODEL",),
                "face_model9": ("FACE_MODEL",),
                "face_model10": ("FACE_MODEL",),
            },
        }

    RETURN_TYPES = ("FACE_MODEL",)
    RETURN_NAMES = ("FACE_MODELS",)
    FUNCTION = "execute"

    CATEGORY = "🌌 ReActor"

    def execute(self, **kwargs):
        if len(kwargs) > 0:
            face_models = [value for value in kwargs.values()]
            return (face_models,)
        else:
            logger.error("Please provide at least 1 `face_model`")
            return (None,)


class ReActorOptions:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "input_faces_order": (
                    ["left-right","right-left","top-bottom","bottom-top","small-large","large-small"], {"default": "large-small"}
                ),
                "input_faces_index": ("STRING", {"default": "0"}),
                "detect_gender_input": (["no","female","male"], {"default": "no"}),
                "source_faces_order": (
                    ["left-right","right-left","top-bottom","bottom-top","small-large","large-small"], {"default": "large-small"}
                ),
                "source_faces_index": ("STRING", {"default": "0"}),
                "detect_gender_source": (["no","female","male"], {"default": "no"}),
                "console_log_level": ([0, 1, 2], {"default": 1}),
            }
        }

    RETURN_TYPES = ("OPTIONS",)
    FUNCTION = "execute"
    CATEGORY = "🌌 ReActor"

    def execute(self, input_faces_order, input_faces_index, detect_gender_input, source_faces_order, source_faces_index, detect_gender_source, console_log_level):
        options: dict = {
            "input_faces_order": input_faces_order,
            "input_faces_index": input_faces_index,
            "detect_gender_input": detect_gender_input,
            "source_faces_order": source_faces_order,
            "source_faces_index": source_faces_index,
            "detect_gender_source": detect_gender_source,
            "console_log_level": console_log_level,
        }
        return (options, )


class ReActorFaceBoost:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "enabled": ("BOOLEAN", {"default": True, "label_off": "OFF", "label_on": "ON"}),
                "boost_model": (get_model_names(get_restorers),),
                "interpolation": (["Nearest","Bilinear","Bicubic","Lanczos"], {"default": "Bicubic"}),
                "visibility": ("FLOAT", {"default": 1, "min": 0.1, "max": 1, "step": 0.05}),
                "codeformer_weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1, "step": 0.05}),
                "restore_with_main_after": ("BOOLEAN", {"default": False}),
            }
        }

    RETURN_TYPES = ("FACE_BOOST",)
    FUNCTION = "execute"
    CATEGORY = "🌌 ReActor"

    def execute(self, enabled, boost_model, interpolation, visibility, codeformer_weight, restore_with_main_after):
        face_boost: dict = {
            "enabled": enabled,
            "boost_model": boost_model,
            "interpolation": interpolation,
            "visibility": visibility,
            "codeformer_weight": codeformer_weight,
            "restore_with_main_after": restore_with_main_after,
        }
        return (face_boost, )


class ReActorUnload:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "trigger": ("IMAGE", ),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"
    CATEGORY = "🌌 ReActor"

    def execute(self, trigger):
        unload_all_models()
        return (trigger,)


NODE_CLASS_MAPPINGS = {
    # --- MAIN NODES ---
    "ReActorFaceSwap": reactor,
    "ReActorFaceSwapOpt": ReActorPlusOpt,
    "ReActorOptions": ReActorOptions,
    "ReActorFaceBoost": ReActorFaceBoost,
    "ReActorMaskHelper": MaskHelper,
    "ReActorSetWeight": ReActorWeight,
    # --- Operations with Face Models ---
    "ReActorSaveFaceModel": SaveFaceModel,
    "ReActorLoadFaceModel": LoadFaceModel,
    "ReActorBuildFaceModel": BuildFaceModel,
    "ReActorMakeFaceModelBatch": MakeFaceModelBatch,
    # --- Additional Nodes ---
    "ReActorRestoreFace": RestoreFace,
    "ReActorImageDublicator": ImageDublicator,
    "ImageRGBA2RGB": ImageRGBA2RGB,
    "ReActorUnload": ReActorUnload,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    # --- MAIN NODES ---
    "ReActorFaceSwap": "ReActor 🌌 Fast Face Swap",
    "ReActorFaceSwapOpt": "ReActor 🌌 Fast Face Swap [OPTIONS]",
    "ReActorOptions": "ReActor 🌌 Options",
    "ReActorFaceBoost": "ReActor 🌌 Face Booster",
    "ReActorMaskHelper": "ReActor 🌌 Masking Helper",
    "ReActorSetWeight": "ReActor 🌌 Set Face Swap Weight",
    # --- Operations with Face Models ---
    "ReActorSaveFaceModel": "Save Face Model 🌌 ReActor",
    "ReActorLoadFaceModel": "Load Face Model 🌌 ReActor",
    "ReActorBuildFaceModel": "Build Blended Face Model 🌌 ReActor",
    "ReActorMakeFaceModelBatch": "Make Face Model Batch 🌌 ReActor",
    # --- Additional Nodes ---
    "ReActorRestoreFace": "Restore Face 🌌 ReActor",
    "ReActorImageDublicator": "Image Dublicator (List) 🌌 ReActor",
    "ImageRGBA2RGB": "Convert RGBA to RGB 🌌 ReActor",
    "ReActorUnload": "Unload ReActor Models 🌌 ReActor",
}
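Every node above follows the same ComfyUI contract: a class exposing an INPUT_TYPES classmethod plus RETURN_TYPES / FUNCTION / CATEGORY attributes, registered through the two mapping dicts. A minimal sketch of that contract (the "ReActorExample" node is hypothetical, shown only to make the pattern explicit):

class ReActorExample:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"image": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"
    CATEGORY = "🌌 ReActor"

    def execute(self, image):
        # a real node would transform the tensor; this one just passes it through
        return (image,)

# registration would then be:
# NODE_CLASS_MAPPINGS["ReActorExample"] = ReActorExample
# NODE_DISPLAY_NAME_MAPPINGS["ReActorExample"] = "Example 🌌 ReActor"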
custom_nodes/ComfyUI-ReActor/pyproject.toml
ADDED
@@ -0,0 +1,15 @@
[project]
name = "comfyui-reactor"
description = "(SFW-Friendly) The Fast and Simple Face Swap Extension Node for ComfyUI, based on ReActor SD-WebUI Face Swap Extension"
version = "0.6.0-a1"
license = { file = "LICENSE" }
dependencies = ["insightface==0.7.3", "onnx>=1.14.0", "opencv-python>=4.7.0.72", "numpy==1.26.3", "segment_anything", "albumentations>=1.4.16", "ultralytics"]

[project.urls]
Repository = "https://github.com/Gourieff/ComfyUI-ReActor"
# Used by Comfy Registry https://comfyregistry.org

[tool.comfy]
PublisherId = "gourieff"
DisplayName = "ComfyUI-ReActor"
Icon = ""
custom_nodes/ComfyUI-ReActor/r_basicsr/__init__.py
ADDED
@@ -0,0 +1,12 @@
# https://github.com/xinntao/BasicSR
# flake8: noqa
from .archs import *
from .data import *
from .losses import *
from .metrics import *
from .models import *
from .ops import *
from .test import *
from .train import *
from .utils import *
from .version import __gitsha__, __version__
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/__init__.py
ADDED
@@ -0,0 +1,25 @@
import importlib
from copy import deepcopy
from os import path as osp

from r_basicsr.utils import get_root_logger, scandir
from r_basicsr.utils.registry import ARCH_REGISTRY

__all__ = ['build_network']

# automatically scan and import arch modules for registry
# scan all the files under the 'archs' folder and collect files ending with
# '_arch.py'
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
_arch_modules = [importlib.import_module(f'r_basicsr.archs.{file_name}') for file_name in arch_filenames]


def build_network(opt):
    opt = deepcopy(opt)
    network_type = opt.pop('type')
    net = ARCH_REGISTRY.get(network_type)(**opt)
    logger = get_root_logger()
    logger.info(f'Network [{net.__class__.__name__}] is created.')
    return net
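To illustrate how build_network is meant to be consumed (a sketch; the option values here are hypothetical): the 'type' key selects a class registered in ARCH_REGISTRY, and every remaining key is forwarded as a constructor kwarg.

from r_basicsr.archs import build_network

opt = {'type': 'BasicVSR', 'num_feat': 64, 'num_block': 15}  # example options dict
net = build_network(opt)  # pops 'type', then calls ARCH_REGISTRY.get('BasicVSR')(num_feat=64, num_block=15)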
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/arch_util.py
ADDED
@@ -0,0 +1,322 @@
import collections.abc
import math
import torch
import torchvision
import warnings
try:
    from distutils.version import LooseVersion
except:
    from packaging.version import Version
    LooseVersion = Version
from itertools import repeat
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm

from r_basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
from r_basicsr.utils import get_root_logger


@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0.
        kwargs (dict): Other arguments for initialization function.
    """
    if not isinstance(module_list, list):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)


def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.

    Args:
        basic_block (nn.module): nn.module class for basic block.
        num_basic_block (int): number of blocks.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)


class ResidualBlockNoBN(nn.Module):
    """Residual block without BN.

    It has a style of:
        ---Conv-ReLU-Conv-+-
         |________________|

    Args:
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Residual scale. Default: 1.
        pytorch_init (bool): If set to True, use pytorch default init,
            otherwise, use default_init_weights. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super(ResidualBlockNoBN, self).__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        if not pytorch_init:
            default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        identity = x
        out = self.conv2(self.relu(self.conv1(x)))
        return identity + out * self.res_scale


class Upsample(nn.Sequential):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)


def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2), normal value.
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Before pytorch 1.3, the default value is
            align_corners=True. After pytorch 1.3, the default value is
            align_corners=False. Here, we use True as default.

    Returns:
        Tensor: Warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, h, w = x.size()
    # create mesh grid
    grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
    grid.requires_grad = False

    vgrid = grid + flow
    # scale grid to [-1,1]
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)

    # TODO, what if align_corners=False
    return output


def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
    """Resize a flow according to ratio or shape.

    Args:
        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): the ratio for resizing or the final output
            shape.
            1) The order of ratio should be [ratio_h, ratio_w]. For
            downsampling, the ratio should be smaller than 1.0 (i.e., ratio
            < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
            ratio > 1.0).
            2) The order of output_size should be [out_h, out_w].
        interp_mode (str): The mode of interpolation for resizing.
            Default: 'bilinear'.
        align_corners (bool): Whether align corners. Default: False.

    Returns:
        Tensor: Resized flow.
    """
    _, _, flow_h, flow_w = flow.size()
    if size_type == 'ratio':
        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
    elif size_type == 'shape':
        output_h, output_w = sizes[0], sizes[1]
    else:
        raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')

    input_flow = flow.clone()
    ratio_h = output_h / flow_h
    ratio_w = output_w / flow_w
    input_flow[:, 0, :, :] *= ratio_w
    input_flow[:, 1, :, :] *= ratio_h
    resized_flow = F.interpolate(
        input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
    return resized_flow


# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
    """Pixel unshuffle.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio.

    Returns:
        Tensor: the pixel unshuffled feature.
    """
    b, c, hh, hw = x.size()
    out_channel = c * (scale**2)
    assert hh % scale == 0 and hw % scale == 0
    h = hh // scale
    w = hw // scale
    x_view = x.view(b, c, h, scale, w, scale)
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
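# Shape check for pixel_unshuffle (an illustrative aside, not original code):
# with x of shape (b, c, hh, hw) = (1, 3, 8, 8) and scale=2, the output has
# shape (1, 12, 4, 4) -- channels grow by scale**2 while both spatial
# dimensions shrink by scale:
#
#   assert pixel_unshuffle(torch.randn(1, 3, 8, 8), 2).shape == (1, 12, 4, 4)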

class DCNv2Pack(ModulatedDeformConvPack):
    """Modulated deformable conv for deformable alignment.

    Different from the official DCNv2Pack, which generates offsets and masks
    from the preceding features, this DCNv2Pack takes another different
    features to generate offsets and masks.

    Ref:
        Delving Deep into Deformable Alignment in Video Super-Resolution.
    """

    def forward(self, x, feat):
        out = self.conv_offset(feat)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)

        offset_absmean = torch.mean(torch.abs(offset))
        if offset_absmean > 50:
            logger = get_root_logger()
            logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')

        if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
            return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
                                                 self.dilation, mask)
        else:
            return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
                                         self.dilation, self.groups, self.deformable_groups)


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
            'The distribution of values may be incorrect.',
            stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        low = norm_cdf((a - mean) / std)
        up = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [low, up], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * low - 1, 2 * up - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution.

    From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py

    The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


# From PyTorch
def _ntuple(n):

    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
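A short usage sketch for the two most reused helpers in this file (mirroring how ConvResidualBlocks in basicvsr_arch.py builds its trunk; the sizes are example values, not requirements):

import torch
from r_basicsr.archs.arch_util import ResidualBlockNoBN, make_layer

trunk = make_layer(ResidualBlockNoBN, 15, num_feat=64)  # 15 stacked residual blocks
out = trunk(torch.randn(1, 64, 32, 32))                 # residual blocks preserve shape: (1, 64, 32, 32)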
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/basicvsr_arch.py
ADDED
@@ -0,0 +1,336 @@
import torch
from torch import nn as nn
from torch.nn import functional as F

from r_basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import ResidualBlockNoBN, flow_warp, make_layer
from .edvr_arch import PCDAlignment, TSAFusion
from .spynet_arch import SpyNet


@ARCH_REGISTRY.register()
class BasicVSR(nn.Module):
    """A recurrent network for video SR. Now only x4 is supported.

    Args:
        num_feat (int): Number of channels. Default: 64.
        num_block (int): Number of residual blocks for each branch. Default: 15.
        spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
    """

    def __init__(self, num_feat=64, num_block=15, spynet_path=None):
        super().__init__()
        self.num_feat = num_feat

        # alignment
        self.spynet = SpyNet(spynet_path)

        # propagation
        self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
        self.forward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)

        # reconstruction
        self.fusion = nn.Conv2d(num_feat * 2, num_feat, 1, 1, 0, bias=True)
        self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)

        self.pixel_shuffle = nn.PixelShuffle(2)

        # activation functions
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def get_flow(self, x):
        b, n, c, h, w = x.size()

        x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
        x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)

        flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
        flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)

        return flows_forward, flows_backward

    def forward(self, x):
        """Forward function of BasicVSR.

        Args:
            x: Input frames with shape (b, n, c, h, w). n is the temporal dimension / number of frames.
        """
        flows_forward, flows_backward = self.get_flow(x)
        b, n, _, h, w = x.size()

        # backward branch
        out_l = []
        feat_prop = x.new_zeros(b, self.num_feat, h, w)
        for i in range(n - 1, -1, -1):
            x_i = x[:, i, :, :, :]
            if i < n - 1:
                flow = flows_backward[:, i, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
            feat_prop = torch.cat([x_i, feat_prop], dim=1)
            feat_prop = self.backward_trunk(feat_prop)
            out_l.insert(0, feat_prop)

        # forward branch
        feat_prop = torch.zeros_like(feat_prop)
        for i in range(0, n):
            x_i = x[:, i, :, :, :]
            if i > 0:
                flow = flows_forward[:, i - 1, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))

            feat_prop = torch.cat([x_i, feat_prop], dim=1)
            feat_prop = self.forward_trunk(feat_prop)

            # upsample
            out = torch.cat([out_l[i], feat_prop], dim=1)
            out = self.lrelu(self.fusion(out))
            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
            out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
            out = self.lrelu(self.conv_hr(out))
            out = self.conv_last(out)
            base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
            out += base
            out_l[i] = out

        return torch.stack(out_l, dim=1)


class ConvResidualBlocks(nn.Module):
    """Conv and residual block used in BasicVSR.

    Args:
        num_in_ch (int): Number of input channels. Default: 3.
        num_out_ch (int): Number of output channels. Default: 64.
        num_block (int): Number of residual blocks. Default: 15.
    """

    def __init__(self, num_in_ch=3, num_out_ch=64, num_block=15):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(num_in_ch, num_out_ch, 3, 1, 1, bias=True), nn.LeakyReLU(negative_slope=0.1, inplace=True),
            make_layer(ResidualBlockNoBN, num_block, num_feat=num_out_ch))

    def forward(self, fea):
        return self.main(fea)


@ARCH_REGISTRY.register()
class IconVSR(nn.Module):
    """IconVSR, proposed also in the BasicVSR paper.

    Args:
        num_feat (int): Number of channels. Default: 64.
        num_block (int): Number of residual blocks for each branch. Default: 15.
        keyframe_stride (int): Keyframe stride. Default: 5.
        temporal_padding (int): Temporal padding. Default: 2.
        spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
        edvr_path (str): Path to the pretrained EDVR model. Default: None.
    """

    def __init__(self,
                 num_feat=64,
                 num_block=15,
                 keyframe_stride=5,
                 temporal_padding=2,
                 spynet_path=None,
                 edvr_path=None):
        super().__init__()

        self.num_feat = num_feat
        self.temporal_padding = temporal_padding
        self.keyframe_stride = keyframe_stride

        # keyframe_branch
        self.edvr = EDVRFeatureExtractor(temporal_padding * 2 + 1, num_feat, edvr_path)
        # alignment
        self.spynet = SpyNet(spynet_path)

        # propagation
        self.backward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
        self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)

        self.forward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
        self.forward_trunk = ConvResidualBlocks(2 * num_feat + 3, num_feat, num_block)

        # reconstruction
        self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)

        self.pixel_shuffle = nn.PixelShuffle(2)

        # activation functions
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def pad_spatial(self, x):
        """Apply padding spatially.

        Since the PCD module in EDVR requires that the resolution is a multiple
        of 4, we apply padding to the input LR images if their resolution is
        not divisible by 4.

        Args:
            x (Tensor): Input LR sequence with shape (n, t, c, h, w).
        Returns:
            Tensor: Padded LR sequence with shape (n, t, c, h_pad, w_pad).
        """
        n, t, c, h, w = x.size()

        pad_h = (4 - h % 4) % 4
        pad_w = (4 - w % 4) % 4

        # padding
        x = x.view(-1, c, h, w)
        x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')

        return x.view(n, t, c, h + pad_h, w + pad_w)

    def get_flow(self, x):
        b, n, c, h, w = x.size()

        x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
        x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)

        flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
        flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)

        return flows_forward, flows_backward

    def get_keyframe_feature(self, x, keyframe_idx):
        if self.temporal_padding == 2:
            x = [x[:, [4, 3]], x, x[:, [-4, -5]]]
        elif self.temporal_padding == 3:
            x = [x[:, [6, 5, 4]], x, x[:, [-5, -6, -7]]]
        x = torch.cat(x, dim=1)

        num_frames = 2 * self.temporal_padding + 1
        feats_keyframe = {}
        for i in keyframe_idx:
            feats_keyframe[i] = self.edvr(x[:, i:i + num_frames].contiguous())
        return feats_keyframe

    def forward(self, x):
        b, n, _, h_input, w_input = x.size()

        x = self.pad_spatial(x)
        h, w = x.shape[3:]

        keyframe_idx = list(range(0, n, self.keyframe_stride))
        if keyframe_idx[-1] != n - 1:
            keyframe_idx.append(n - 1)  # last frame is a keyframe

        # compute flow and keyframe features
        flows_forward, flows_backward = self.get_flow(x)
        feats_keyframe = self.get_keyframe_feature(x, keyframe_idx)

        # backward branch
        out_l = []
        feat_prop = x.new_zeros(b, self.num_feat, h, w)
        for i in range(n - 1, -1, -1):
            x_i = x[:, i, :, :, :]
            if i < n - 1:
                flow = flows_backward[:, i, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
            if i in keyframe_idx:
                feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
                feat_prop = self.backward_fusion(feat_prop)
            feat_prop = torch.cat([x_i, feat_prop], dim=1)
            feat_prop = self.backward_trunk(feat_prop)
            out_l.insert(0, feat_prop)

        # forward branch
        feat_prop = torch.zeros_like(feat_prop)
        for i in range(0, n):
            x_i = x[:, i, :, :, :]
            if i > 0:
                flow = flows_forward[:, i - 1, :, :, :]
                feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
            if i in keyframe_idx:
                feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
                feat_prop = self.forward_fusion(feat_prop)

            feat_prop = torch.cat([x_i, out_l[i], feat_prop], dim=1)
            feat_prop = self.forward_trunk(feat_prop)

            # upsample
            out = self.lrelu(self.pixel_shuffle(self.upconv1(feat_prop)))
            out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
            out = self.lrelu(self.conv_hr(out))
            out = self.conv_last(out)
            base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
            out += base
            out_l[i] = out

        return torch.stack(out_l, dim=1)[..., :4 * h_input, :4 * w_input]


class EDVRFeatureExtractor(nn.Module):
    """EDVR feature extractor used in IconVSR.

    Args:
        num_input_frame (int): Number of input frames.
        num_feat (int): Number of feature channels.
        load_path (str): Path to the pretrained weights of EDVR. Default: None.
    """

    def __init__(self, num_input_frame, num_feat, load_path):

        super(EDVRFeatureExtractor, self).__init__()

        self.center_frame_idx = num_input_frame // 2

        # extract pyramid features
        self.conv_first = nn.Conv2d(3, num_feat, 3, 1, 1)
        self.feature_extraction = make_layer(ResidualBlockNoBN, 5, num_feat=num_feat)
        self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

        # pcd and tsa module
        self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=8)
        self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_input_frame, center_frame_idx=self.center_frame_idx)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        if load_path:
            self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])

    def forward(self, x):
        b, n, c, h, w = x.size()

        # extract features for each frame
        # L1
        feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))
        feat_l1 = self.feature_extraction(feat_l1)
        # L2
        feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
        feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
        # L3
        feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
        feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))

        feat_l1 = feat_l1.view(b, n, -1, h, w)
        feat_l2 = feat_l2.view(b, n, -1, h // 2, w // 2)
        feat_l3 = feat_l3.view(b, n, -1, h // 4, w // 4)

        # PCD alignment
        ref_feat_l = [  # reference feature list
            feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
            feat_l3[:, self.center_frame_idx, :, :, :].clone()
        ]
        aligned_feat = []
        for i in range(n):
            nbr_feat_l = [  # neighboring feature list
                feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
            ]
            aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
        aligned_feat = torch.stack(aligned_feat, dim=1)  # (b, t, c, h, w)

        # TSA fusion
        return self.fusion(aligned_feat)
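An illustrative smoke test for BasicVSR (a sketch: with spynet_path=None the flow network runs with random weights, so the output is only useful for shape checks; per the docstring the model upsamples x4):

import torch
from r_basicsr.archs.basicvsr_arch import BasicVSR

model = BasicVSR(num_feat=64, num_block=15, spynet_path=None).eval()
with torch.no_grad():
    lr = torch.randn(1, 5, 3, 64, 64)  # (b, n, c, h, w): 5 low-res frames
    sr = model(lr)                     # expected shape: (1, 5, 3, 256, 256)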
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/basicvsrpp_arch.py
ADDED
@@ -0,0 +1,407 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import warnings

from r_basicsr.archs.arch_util import flow_warp
from r_basicsr.archs.basicvsr_arch import ConvResidualBlocks
from r_basicsr.archs.spynet_arch import SpyNet
from r_basicsr.ops.dcn import ModulatedDeformConvPack
from r_basicsr.utils.registry import ARCH_REGISTRY


@ARCH_REGISTRY.register()
class BasicVSRPlusPlus(nn.Module):
    """BasicVSR++ network structure.
    Supports either x4 upsampling or same-size output. Since DCN is used in
    this model, it can only be used with CUDA enabled. If CUDA is not enabled,
    feature alignment will be skipped. Besides, we adopt the official DCN
    implementation, and the torch version needs to be higher than 1.9.
    Paper:
        BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation
        and Alignment
    Args:
        mid_channels (int, optional): Channel number of the intermediate
            features. Default: 64.
        num_blocks (int, optional): The number of residual blocks in each
            propagation branch. Default: 7.
        max_residue_magnitude (int): The maximum magnitude of the offset
            residue (Eq. 6 in paper). Default: 10.
        is_low_res_input (bool, optional): Whether the input is low-resolution
            or not. If False, the output resolution is equal to the input
            resolution. Default: True.
        spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
        cpu_cache_length (int, optional): When the length of sequence is larger
            than this value, the intermediate features are sent to CPU. This
            saves GPU memory, but slows down the inference speed. You can
            increase this number if you have a GPU with large memory.
            Default: 100.
    """

    def __init__(self,
                 mid_channels=64,
                 num_blocks=7,
                 max_residue_magnitude=10,
                 is_low_res_input=True,
                 spynet_path=None,
                 cpu_cache_length=100):

        super().__init__()
        self.mid_channels = mid_channels
        self.is_low_res_input = is_low_res_input
        self.cpu_cache_length = cpu_cache_length

        # optical flow
        self.spynet = SpyNet(spynet_path)

        # feature extraction module
        if is_low_res_input:
            self.feat_extract = ConvResidualBlocks(3, mid_channels, 5)
        else:
            self.feat_extract = nn.Sequential(
                nn.Conv2d(3, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
                nn.Conv2d(mid_channels, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
                ConvResidualBlocks(mid_channels, mid_channels, 5))

        # propagation branches
        self.deform_align = nn.ModuleDict()
        self.backbone = nn.ModuleDict()
        modules = ['backward_1', 'forward_1', 'backward_2', 'forward_2']
        for i, module in enumerate(modules):
            if torch.cuda.is_available():
                self.deform_align[module] = SecondOrderDeformableAlignment(
                    2 * mid_channels,
                    mid_channels,
                    3,
                    padding=1,
                    deformable_groups=16,
                    max_residue_magnitude=max_residue_magnitude)
            self.backbone[module] = ConvResidualBlocks((2 + i) * mid_channels, mid_channels, num_blocks)

        # upsampling module
        self.reconstruction = ConvResidualBlocks(5 * mid_channels, mid_channels, 5)

        self.upconv1 = nn.Conv2d(mid_channels, mid_channels * 4, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(mid_channels, 64 * 4, 3, 1, 1, bias=True)

        self.pixel_shuffle = nn.PixelShuffle(2)

        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
        self.img_upsample = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        # check if the sequence is augmented by flipping
        self.is_mirror_extended = False

        if len(self.deform_align) > 0:
            self.is_with_alignment = True
        else:
            self.is_with_alignment = False
            warnings.warn('Deformable alignment module is not added. '
                          'Probably your CUDA is not configured correctly. DCN can only '
                          'be used with CUDA enabled. Alignment is skipped now.')

    def check_if_mirror_extended(self, lqs):
        """Check whether the input is a mirror-extended sequence.
        If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the
        (t-1-i)-th frame.
        Args:
            lqs (tensor): Input low quality (LQ) sequence with
                shape (n, t, c, h, w).
        """

        if lqs.size(1) % 2 == 0:
            lqs_1, lqs_2 = torch.chunk(lqs, 2, dim=1)
            if torch.norm(lqs_1 - lqs_2.flip(1)) == 0:
                self.is_mirror_extended = True

    def compute_flow(self, lqs):
        """Compute optical flow using SPyNet for feature alignment.
        Note that if the input is a mirror-extended sequence, 'flows_forward'
        is not needed, since it is equal to 'flows_backward.flip(1)'.
        Args:
            lqs (tensor): Input low quality (LQ) sequence with
                shape (n, t, c, h, w).
        Return:
            tuple(Tensor): Optical flow. 'flows_forward' corresponds to the
                flows used for forward-time propagation (current to previous).
                'flows_backward' corresponds to the flows used for
                backward-time propagation (current to next).
        """

        n, t, c, h, w = lqs.size()
        lqs_1 = lqs[:, :-1, :, :, :].reshape(-1, c, h, w)
        lqs_2 = lqs[:, 1:, :, :, :].reshape(-1, c, h, w)

        flows_backward = self.spynet(lqs_1, lqs_2).view(n, t - 1, 2, h, w)

        if self.is_mirror_extended:  # flows_forward = flows_backward.flip(1)
            flows_forward = flows_backward.flip(1)
        else:
            flows_forward = self.spynet(lqs_2, lqs_1).view(n, t - 1, 2, h, w)

        if self.cpu_cache:
            flows_backward = flows_backward.cpu()
            flows_forward = flows_forward.cpu()

        return flows_forward, flows_backward

    def propagate(self, feats, flows, module_name):
        """Propagate the latent features throughout the sequence.
        Args:
            feats dict(list[tensor]): Features from previous branches. Each
                component is a list of tensors with shape (n, c, h, w).
            flows (tensor): Optical flows with shape (n, t - 1, 2, h, w).
            module_name (str): The name of the propagation branches. Can either
                be 'backward_1', 'forward_1', 'backward_2', 'forward_2'.
        Return:
            dict(list[tensor]): A dictionary containing all the propagated
                features. Each key in the dictionary corresponds to a
                propagation branch, which is represented by a list of tensors.
        """

        n, t, _, h, w = flows.size()

        frame_idx = range(0, t + 1)
        flow_idx = range(-1, t)
        mapping_idx = list(range(0, len(feats['spatial'])))
        mapping_idx += mapping_idx[::-1]

        if 'backward' in module_name:
            frame_idx = frame_idx[::-1]
            flow_idx = frame_idx

        feat_prop = flows.new_zeros(n, self.mid_channels, h, w)
        for i, idx in enumerate(frame_idx):
            feat_current = feats['spatial'][mapping_idx[idx]]
            if self.cpu_cache:
                feat_current = feat_current.cuda()
                feat_prop = feat_prop.cuda()
            # second-order deformable alignment
            if i > 0 and self.is_with_alignment:
                flow_n1 = flows[:, flow_idx[i], :, :, :]
                if self.cpu_cache:
                    flow_n1 = flow_n1.cuda()

                cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))

                # initialize second-order features
                feat_n2 = torch.zeros_like(feat_prop)
                flow_n2 = torch.zeros_like(flow_n1)
                cond_n2 = torch.zeros_like(cond_n1)

                if i > 1:  # second-order features
                    feat_n2 = feats[module_name][-2]
                    if self.cpu_cache:
                        feat_n2 = feat_n2.cuda()

                    flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
                    if self.cpu_cache:
                        flow_n2 = flow_n2.cuda()

                    flow_n2 = flow_n1 + flow_warp(flow_n2, flow_n1.permute(0, 2, 3, 1))
                    cond_n2 = flow_warp(feat_n2, flow_n2.permute(0, 2, 3, 1))

                # flow-guided deformable convolution
                cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
                feat_prop = torch.cat([feat_prop, feat_n2], dim=1)
                feat_prop = self.deform_align[module_name](feat_prop, cond, flow_n1, flow_n2)

            # concatenate and residual blocks
            feat = [feat_current] + [feats[k][idx] for k in feats if k not in ['spatial', module_name]] + [feat_prop]
            if self.cpu_cache:
                feat = [f.cuda() for f in feat]

            feat = torch.cat(feat, dim=1)
            feat_prop = feat_prop + self.backbone[module_name](feat)
            feats[module_name].append(feat_prop)

            if self.cpu_cache:
                feats[module_name][-1] = feats[module_name][-1].cpu()
                torch.cuda.empty_cache()

        if 'backward' in module_name:
            feats[module_name] = feats[module_name][::-1]

        return feats
+
|
232 |
+
def upsample(self, lqs, feats):
|
233 |
+
"""Compute the output image given the features.
|
234 |
+
Args:
|
235 |
+
lqs (tensor): Input low quality (LQ) sequence with
|
236 |
+
shape (n, t, c, h, w).
|
237 |
+
feats (dict): The features from the propgation branches.
|
238 |
+
Returns:
|
239 |
+
Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
|
240 |
+
"""
|
241 |
+
|
242 |
+
outputs = []
|
243 |
+
num_outputs = len(feats['spatial'])
|
244 |
+
|
245 |
+
mapping_idx = list(range(0, num_outputs))
|
246 |
+
mapping_idx += mapping_idx[::-1]
|
247 |
+
|
248 |
+
for i in range(0, lqs.size(1)):
|
249 |
+
hr = [feats[k].pop(0) for k in feats if k != 'spatial']
|
250 |
+
hr.insert(0, feats['spatial'][mapping_idx[i]])
|
251 |
+
hr = torch.cat(hr, dim=1)
|
252 |
+
if self.cpu_cache:
|
253 |
+
hr = hr.cuda()
|
254 |
+
|
255 |
+
hr = self.reconstruction(hr)
|
256 |
+
hr = self.lrelu(self.pixel_shuffle(self.upconv1(hr)))
|
257 |
+
hr = self.lrelu(self.pixel_shuffle(self.upconv2(hr)))
|
258 |
+
hr = self.lrelu(self.conv_hr(hr))
|
259 |
+
hr = self.conv_last(hr)
|
260 |
+
if self.is_low_res_input:
|
261 |
+
hr += self.img_upsample(lqs[:, i, :, :, :])
|
262 |
+
else:
|
263 |
+
hr += lqs[:, i, :, :, :]
|
264 |
+
|
265 |
+
if self.cpu_cache:
|
266 |
+
hr = hr.cpu()
|
267 |
+
torch.cuda.empty_cache()
|
268 |
+
|
269 |
+
outputs.append(hr)
|
270 |
+
|
271 |
+
return torch.stack(outputs, dim=1)
|
272 |
+
|
273 |
+
def forward(self, lqs):
|
274 |
+
"""Forward function for BasicVSR++.
|
275 |
+
Args:
|
276 |
+
lqs (tensor): Input low quality (LQ) sequence with
|
277 |
+
shape (n, t, c, h, w).
|
278 |
+
Returns:
|
279 |
+
Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
|
280 |
+
"""
|
281 |
+
|
282 |
+
n, t, c, h, w = lqs.size()
|
283 |
+
|
284 |
+
# whether to cache the features in CPU
|
285 |
+
self.cpu_cache = True if t > self.cpu_cache_length else False
|
286 |
+
|
287 |
+
if self.is_low_res_input:
|
288 |
+
lqs_downsample = lqs.clone()
|
289 |
+
else:
|
290 |
+
lqs_downsample = F.interpolate(
|
291 |
+
lqs.view(-1, c, h, w), scale_factor=0.25, mode='bicubic').view(n, t, c, h // 4, w // 4)
|
292 |
+
|
293 |
+
# check whether the input is an extended sequence
|
294 |
+
self.check_if_mirror_extended(lqs)
|
295 |
+
|
296 |
+
feats = {}
|
297 |
+
# compute spatial features
|
298 |
+
if self.cpu_cache:
|
299 |
+
feats['spatial'] = []
|
300 |
+
for i in range(0, t):
|
301 |
+
feat = self.feat_extract(lqs[:, i, :, :, :]).cpu()
|
302 |
+
feats['spatial'].append(feat)
|
303 |
+
torch.cuda.empty_cache()
|
304 |
+
else:
|
305 |
+
feats_ = self.feat_extract(lqs.view(-1, c, h, w))
|
306 |
+
h, w = feats_.shape[2:]
|
307 |
+
feats_ = feats_.view(n, t, -1, h, w)
|
308 |
+
feats['spatial'] = [feats_[:, i, :, :, :] for i in range(0, t)]
|
309 |
+
|
310 |
+
# compute optical flow using the low-res inputs
|
311 |
+
assert lqs_downsample.size(3) >= 64 and lqs_downsample.size(4) >= 64, (
|
312 |
+
'The height and width of low-res inputs must be at least 64, '
|
313 |
+
f'but got {h} and {w}.')
|
314 |
+
flows_forward, flows_backward = self.compute_flow(lqs_downsample)
|
315 |
+
|
316 |
+
# feature propgation
|
317 |
+
for iter_ in [1, 2]:
|
318 |
+
for direction in ['backward', 'forward']:
|
319 |
+
module = f'{direction}_{iter_}'
|
320 |
+
|
321 |
+
feats[module] = []
|
322 |
+
|
323 |
+
if direction == 'backward':
|
324 |
+
flows = flows_backward
|
325 |
+
elif flows_forward is not None:
|
326 |
+
flows = flows_forward
|
327 |
+
else:
|
328 |
+
flows = flows_backward.flip(1)
|
329 |
+
|
330 |
+
feats = self.propagate(feats, flows, module)
|
331 |
+
if self.cpu_cache:
|
332 |
+
del flows
|
333 |
+
torch.cuda.empty_cache()
|
334 |
+
|
335 |
+
return self.upsample(lqs, feats)
|
336 |
+
|
337 |
+
|
338 |
+
class SecondOrderDeformableAlignment(ModulatedDeformConvPack):
|
339 |
+
"""Second-order deformable alignment module.
|
340 |
+
Args:
|
341 |
+
in_channels (int): Same as nn.Conv2d.
|
342 |
+
out_channels (int): Same as nn.Conv2d.
|
343 |
+
kernel_size (int or tuple[int]): Same as nn.Conv2d.
|
344 |
+
stride (int or tuple[int]): Same as nn.Conv2d.
|
345 |
+
padding (int or tuple[int]): Same as nn.Conv2d.
|
346 |
+
dilation (int or tuple[int]): Same as nn.Conv2d.
|
347 |
+
groups (int): Same as nn.Conv2d.
|
348 |
+
bias (bool or str): If specified as `auto`, it will be decided by the
|
349 |
+
norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
|
350 |
+
False.
|
351 |
+
max_residue_magnitude (int): The maximum magnitude of the offset
|
352 |
+
residue (Eq. 6 in paper). Default: 10.
|
353 |
+
"""
|
354 |
+
|
355 |
+
def __init__(self, *args, **kwargs):
|
356 |
+
self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)
|
357 |
+
|
358 |
+
super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)
|
359 |
+
|
360 |
+
self.conv_offset = nn.Sequential(
|
361 |
+
nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1),
|
362 |
+
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
363 |
+
nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
|
364 |
+
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
365 |
+
nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
|
366 |
+
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
367 |
+
nn.Conv2d(self.out_channels, 27 * self.deformable_groups, 3, 1, 1),
|
368 |
+
)
|
369 |
+
|
370 |
+
self.init_offset()
|
371 |
+
|
372 |
+
def init_offset(self):
|
373 |
+
|
374 |
+
def _constant_init(module, val, bias=0):
|
375 |
+
if hasattr(module, 'weight') and module.weight is not None:
|
376 |
+
nn.init.constant_(module.weight, val)
|
377 |
+
if hasattr(module, 'bias') and module.bias is not None:
|
378 |
+
nn.init.constant_(module.bias, bias)
|
379 |
+
|
380 |
+
_constant_init(self.conv_offset[-1], val=0, bias=0)
|
381 |
+
|
382 |
+
def forward(self, x, extra_feat, flow_1, flow_2):
|
383 |
+
extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1)
|
384 |
+
out = self.conv_offset(extra_feat)
|
385 |
+
o1, o2, mask = torch.chunk(out, 3, dim=1)
|
386 |
+
|
387 |
+
# offset
|
388 |
+
offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1))
|
389 |
+
offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
|
390 |
+
offset_1 = offset_1 + flow_1.flip(1).repeat(1, offset_1.size(1) // 2, 1, 1)
|
391 |
+
offset_2 = offset_2 + flow_2.flip(1).repeat(1, offset_2.size(1) // 2, 1, 1)
|
392 |
+
offset = torch.cat([offset_1, offset_2], dim=1)
|
393 |
+
|
394 |
+
# mask
|
395 |
+
mask = torch.sigmoid(mask)
|
396 |
+
|
397 |
+
return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
|
398 |
+
self.dilation, mask)
|
399 |
+
|
400 |
+
|
401 |
+
# if __name__ == '__main__':
|
402 |
+
# spynet_path = 'experiments/pretrained_models/flownet/spynet_sintel_final-3d2a1287.pth'
|
403 |
+
# model = BasicVSRPlusPlus(spynet_path=spynet_path).cuda()
|
404 |
+
# input = torch.rand(1, 2, 3, 64, 64).cuda()
|
405 |
+
# output = model(input)
|
406 |
+
# print('===================')
|
407 |
+
# print(output.shape)
|
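A quick aside on the mirror-extension test above: splitting a mirror-extended clip in half and flipping the second half in time reproduces the first half exactly, which is why the zero-norm check works. A minimal sketch with hypothetical tensors (not part of the diff):

import torch

# Hypothetical 5D clip (n, t, c, h, w); the second half mirrors the first in time.
frames = torch.rand(1, 3, 3, 8, 8)
clip = torch.cat([frames, frames.flip(1)], dim=1)  # t = 6, mirror-extended

first, second = torch.chunk(clip, 2, dim=1)
# Flipping the second half in time reproduces the first half exactly,
# so the norm of the difference is zero -- the test used above.
print(torch.norm(first - second.flip(1)))  # tensor(0.)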
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/dfdnet_arch.py
ADDED
@@ -0,0 +1,169 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils.spectral_norm import spectral_norm
+
+from r_basicsr.utils.registry import ARCH_REGISTRY
+from .dfdnet_util import AttentionBlock, Blur, MSDilationBlock, UpResBlock, adaptive_instance_normalization
+from .vgg_arch import VGGFeatureExtractor
+
+
+class SFTUpBlock(nn.Module):
+    """Spatial feature transform (SFT) with upsampling block.
+
+    Args:
+        in_channel (int): Number of input channels.
+        out_channel (int): Number of output channels.
+        kernel_size (int): Kernel size in convolutions. Default: 3.
+        padding (int): Padding in convolutions. Default: 1.
+    """
+
+    def __init__(self, in_channel, out_channel, kernel_size=3, padding=1):
+        super(SFTUpBlock, self).__init__()
+        self.conv1 = nn.Sequential(
+            Blur(in_channel),
+            spectral_norm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
+            nn.LeakyReLU(0.04, True),
+            # The official codes use two LeakyReLU here, so 0.04 for equivalent
+        )
+        self.convup = nn.Sequential(
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+            spectral_norm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
+            nn.LeakyReLU(0.2, True),
+        )
+
+        # for SFT scale and shift
+        self.scale_block = nn.Sequential(
+            spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
+            spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)))
+        self.shift_block = nn.Sequential(
+            spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
+            spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)), nn.Sigmoid())
+        # The official code uses sigmoid for the shift block; the reason is unclear
+
+    def forward(self, x, updated_feat):
+        out = self.conv1(x)
+        # SFT
+        scale = self.scale_block(updated_feat)
+        shift = self.shift_block(updated_feat)
+        out = out * scale + shift
+        # upsample
+        out = self.convup(out)
+        return out
+
+
+@ARCH_REGISTRY.register()
+class DFDNet(nn.Module):
+    """DFDNet: Deep Face Dictionary Network.
+
+    It only processes faces of size 512x512.
+
+    Args:
+        num_feat (int): Number of feature channels.
+        dict_path (str): Path to the facial component dictionary.
+    """
+
+    def __init__(self, num_feat, dict_path):
+        super().__init__()
+        self.parts = ['left_eye', 'right_eye', 'nose', 'mouth']
+        # part_sizes: [80, 80, 50, 110]
+        channel_sizes = [128, 256, 512, 512]
+        self.feature_sizes = np.array([256, 128, 64, 32])
+        self.vgg_layers = ['relu2_2', 'relu3_4', 'relu4_4', 'conv5_4']
+        self.flag_dict_device = False
+
+        # dict
+        self.dict = torch.load(dict_path)
+
+        # vgg face extractor
+        self.vgg_extractor = VGGFeatureExtractor(
+            layer_name_list=self.vgg_layers,
+            vgg_type='vgg19',
+            use_input_norm=True,
+            range_norm=True,
+            requires_grad=False)
+
+        # attention block for fusing dictionary features and input features
+        self.attn_blocks = nn.ModuleDict()
+        for idx, feat_size in enumerate(self.feature_sizes):
+            for name in self.parts:
+                self.attn_blocks[f'{name}_{feat_size}'] = AttentionBlock(channel_sizes[idx])
+
+        # multi scale dilation block
+        self.multi_scale_dilation = MSDilationBlock(num_feat * 8, dilation=[4, 3, 2, 1])
+
+        # upsampling and reconstruction
+        self.upsample0 = SFTUpBlock(num_feat * 8, num_feat * 8)
+        self.upsample1 = SFTUpBlock(num_feat * 8, num_feat * 4)
+        self.upsample2 = SFTUpBlock(num_feat * 4, num_feat * 2)
+        self.upsample3 = SFTUpBlock(num_feat * 2, num_feat)
+        self.upsample4 = nn.Sequential(
+            spectral_norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1)), nn.LeakyReLU(0.2, True), UpResBlock(num_feat),
+            UpResBlock(num_feat), nn.Conv2d(num_feat, 3, kernel_size=3, stride=1, padding=1), nn.Tanh())
+
+    def swap_feat(self, vgg_feat, updated_feat, dict_feat, location, part_name, f_size):
+        """Swap the features from the dictionary."""
+        # get the original vgg features
+        part_feat = vgg_feat[:, :, location[1]:location[3], location[0]:location[2]].clone()
+        # resize original vgg features
+        part_resize_feat = F.interpolate(part_feat, dict_feat.size()[2:4], mode='bilinear', align_corners=False)
+        # use adaptive instance normalization to adjust color and illuminations
+        dict_feat = adaptive_instance_normalization(dict_feat, part_resize_feat)
+        # get similarity scores
+        similarity_score = F.conv2d(part_resize_feat, dict_feat)
+        similarity_score = F.softmax(similarity_score.view(-1), dim=0)
+        # select the most similar features in the dict (after norm)
+        select_idx = torch.argmax(similarity_score)
+        swap_feat = F.interpolate(dict_feat[select_idx:select_idx + 1], part_feat.size()[2:4])
+        # attention
+        attn = self.attn_blocks[f'{part_name}_' + str(f_size)](swap_feat - part_feat)
+        attn_feat = attn * swap_feat
+        # update features
+        updated_feat[:, :, location[1]:location[3], location[0]:location[2]] = attn_feat + part_feat
+        return updated_feat
+
+    def put_dict_to_device(self, x):
+        if self.flag_dict_device is False:
+            for k, v in self.dict.items():
+                for kk, vv in v.items():
+                    self.dict[k][kk] = vv.to(x)
+            self.flag_dict_device = True
+
+    def forward(self, x, part_locations):
+        """
+        Now only supports testing with batch size = 1.
+
+        Args:
+            x (Tensor): Input faces with shape (b, c, 512, 512).
+            part_locations (list[Tensor]): Part locations.
+        """
+        self.put_dict_to_device(x)
+        # extract vggface features
+        vgg_features = self.vgg_extractor(x)
+        # update vggface features using the dictionary for each part
+        updated_vgg_features = []
+        batch = 0  # batch index; only supports testing with batch size = 1
+        for vgg_layer, f_size in zip(self.vgg_layers, self.feature_sizes):
+            dict_features = self.dict[f'{f_size}']
+            vgg_feat = vgg_features[vgg_layer]
+            updated_feat = vgg_feat.clone()
+
+            # swap features from dictionary
+            for part_idx, part_name in enumerate(self.parts):
+                location = (part_locations[part_idx][batch] // (512 / f_size)).int()
+                updated_feat = self.swap_feat(vgg_feat, updated_feat, dict_features[part_name], location, part_name,
+                                              f_size)
+
+            updated_vgg_features.append(updated_feat)
+
+        vgg_feat_dilation = self.multi_scale_dilation(vgg_features['conv5_4'])
+        # use updated vgg features to modulate the upsampled features in an
+        # SFT (Spatial Feature Transform) scale-and-shift manner.
+        upsampled_feat = self.upsample0(vgg_feat_dilation, updated_vgg_features[3])
+        upsampled_feat = self.upsample1(upsampled_feat, updated_vgg_features[2])
+        upsampled_feat = self.upsample2(upsampled_feat, updated_vgg_features[1])
+        upsampled_feat = self.upsample3(upsampled_feat, updated_vgg_features[0])
+        out = self.upsample4(upsampled_feat)
+
+        return out
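On the dictionary lookup in swap_feat above: each dictionary entry is used as a full-size convolution kernel over the resized part feature, so a single F.conv2d call yields one correlation score per candidate. A minimal sketch with hypothetical shapes (not part of the diff):

import torch
import torch.nn.functional as F

# Hypothetical shapes: one degraded part feature and a dictionary of 4 candidates.
part_feat = torch.rand(1, 8, 10, 10)  # (1, c, h, w)
dict_feat = torch.rand(4, 8, 10, 10)  # (num_candidates, c, h, w)

# Each dictionary entry acts as a kernel of the same spatial size, so the
# convolution collapses to one correlation score per candidate.
scores = F.conv2d(part_feat, dict_feat)     # (1, 4, 1, 1)
scores = F.softmax(scores.view(-1), dim=0)  # normalized similarity
print(torch.argmax(scores), scores)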
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/dfdnet_util.py
ADDED
@@ -0,0 +1,162 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.nn.utils.spectral_norm import spectral_norm
+
+
+class BlurFunctionBackward(Function):
+
+    @staticmethod
+    def forward(ctx, grad_output, kernel, kernel_flip):
+        ctx.save_for_backward(kernel, kernel_flip)
+        grad_input = F.conv2d(grad_output, kernel_flip, padding=1, groups=grad_output.shape[1])
+        return grad_input
+
+    @staticmethod
+    def backward(ctx, gradgrad_output):
+        kernel, _ = ctx.saved_tensors
+        grad_input = F.conv2d(gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1])
+        return grad_input, None, None
+
+
+class BlurFunction(Function):
+
+    @staticmethod
+    def forward(ctx, x, kernel, kernel_flip):
+        ctx.save_for_backward(kernel, kernel_flip)
+        output = F.conv2d(x, kernel, padding=1, groups=x.shape[1])
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        kernel, kernel_flip = ctx.saved_tensors
+        grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip)
+        return grad_input, None, None
+
+
+blur = BlurFunction.apply
+
+
+class Blur(nn.Module):
+
+    def __init__(self, channel):
+        super().__init__()
+        kernel = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
+        kernel = kernel.view(1, 1, 3, 3)
+        kernel = kernel / kernel.sum()
+        kernel_flip = torch.flip(kernel, [2, 3])
+
+        self.kernel = kernel.repeat(channel, 1, 1, 1)
+        self.kernel_flip = kernel_flip.repeat(channel, 1, 1, 1)
+
+    def forward(self, x):
+        return blur(x, self.kernel.type_as(x), self.kernel_flip.type_as(x))
+
+
+def calc_mean_std(feat, eps=1e-5):
+    """Calculate mean and std for adaptive_instance_normalization.
+
+    Args:
+        feat (Tensor): 4D tensor.
+        eps (float): A small value added to the variance to avoid
+            divide-by-zero. Default: 1e-5.
+    """
+    size = feat.size()
+    assert len(size) == 4, 'The input feature should be 4D tensor.'
+    n, c = size[:2]
+    feat_var = feat.view(n, c, -1).var(dim=2) + eps
+    feat_std = feat_var.sqrt().view(n, c, 1, 1)
+    feat_mean = feat.view(n, c, -1).mean(dim=2).view(n, c, 1, 1)
+    return feat_mean, feat_std
+
+
+def adaptive_instance_normalization(content_feat, style_feat):
+    """Adaptive instance normalization.
+
+    Adjust the reference features to have similar color and illumination
+    to those in the degraded features.
+
+    Args:
+        content_feat (Tensor): The reference features.
+        style_feat (Tensor): The degraded features.
+    """
+    size = content_feat.size()
+    style_mean, style_std = calc_mean_std(style_feat)
+    content_mean, content_std = calc_mean_std(content_feat)
+    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+    return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+
+def AttentionBlock(in_channel):
+    return nn.Sequential(
+        spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
+        spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)))
+
+
+def conv_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True):
+    """Conv block used in MSDilationBlock."""
+
+    return nn.Sequential(
+        spectral_norm(
+            nn.Conv2d(
+                in_channels,
+                out_channels,
+                kernel_size=kernel_size,
+                stride=stride,
+                dilation=dilation,
+                padding=((kernel_size - 1) // 2) * dilation,
+                bias=bias)),
+        nn.LeakyReLU(0.2),
+        spectral_norm(
+            nn.Conv2d(
+                out_channels,
+                out_channels,
+                kernel_size=kernel_size,
+                stride=stride,
+                dilation=dilation,
+                padding=((kernel_size - 1) // 2) * dilation,
+                bias=bias)),
+    )
+
+
+class MSDilationBlock(nn.Module):
+    """Multi-scale dilation block."""
+
+    def __init__(self, in_channels, kernel_size=3, dilation=(1, 1, 1, 1), bias=True):
+        super(MSDilationBlock, self).__init__()
+
+        self.conv_blocks = nn.ModuleList()
+        for i in range(4):
+            self.conv_blocks.append(conv_block(in_channels, in_channels, kernel_size, dilation=dilation[i], bias=bias))
+        self.conv_fusion = spectral_norm(
+            nn.Conv2d(
+                in_channels * 4,
+                in_channels,
+                kernel_size=kernel_size,
+                stride=1,
+                padding=(kernel_size - 1) // 2,
+                bias=bias))
+
+    def forward(self, x):
+        out = []
+        for i in range(4):
+            out.append(self.conv_blocks[i](x))
+        out = torch.cat(out, 1)
+        out = self.conv_fusion(out) + x
+        return out
+
+
+class UpResBlock(nn.Module):
+
+    def __init__(self, in_channel):
+        super(UpResBlock, self).__init__()
+        self.body = nn.Sequential(
+            nn.Conv2d(in_channel, in_channel, 3, 1, 1),
+            nn.LeakyReLU(0.2, True),
+            nn.Conv2d(in_channel, in_channel, 3, 1, 1),
+        )
+
+    def forward(self, x):
+        out = x + self.body(x)
+        return out
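On adaptive_instance_normalization above: it transfers the per-channel mean and standard deviation of the style (degraded) features onto the content features. A small self-contained check with hypothetical tensors (not part of the diff):

import torch

# Hypothetical content/style tensors with clearly different statistics.
content = torch.rand(2, 4, 16, 16) * 3 + 1
style = torch.rand(2, 4, 16, 16) * 0.5 - 2

def stats(feat, eps=1e-5):
    n, c = feat.shape[:2]
    var = feat.view(n, c, -1).var(dim=2) + eps
    return feat.view(n, c, -1).mean(dim=2).view(n, c, 1, 1), var.sqrt().view(n, c, 1, 1)

c_mean, c_std = stats(content)
s_mean, s_std = stats(style)
normalized = (content - c_mean) / c_std * s_std + s_mean

# The normalized features adopt the style feature's per-channel mean.
print(torch.allclose(stats(normalized)[0], s_mean, atol=1e-4))  # True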
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/discriminator_arch.py
ADDED
@@ -0,0 +1,150 @@
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn.utils import spectral_norm
+
+from r_basicsr.utils.registry import ARCH_REGISTRY
+
+
+@ARCH_REGISTRY.register()
+class VGGStyleDiscriminator(nn.Module):
+    """VGG style discriminator with input size 128 x 128 or 256 x 256.
+
+    It is used to train SRGAN, ESRGAN, and VideoGAN.
+
+    Args:
+        num_in_ch (int): Channel number of inputs. Default: 3.
+        num_feat (int): Channel number of base intermediate features. Default: 64.
+    """
+
+    def __init__(self, num_in_ch, num_feat, input_size=128):
+        super(VGGStyleDiscriminator, self).__init__()
+        self.input_size = input_size
+        assert self.input_size == 128 or self.input_size == 256, (
+            f'input size must be 128 or 256, but received {input_size}')
+
+        self.conv0_0 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
+        self.conv0_1 = nn.Conv2d(num_feat, num_feat, 4, 2, 1, bias=False)
+        self.bn0_1 = nn.BatchNorm2d(num_feat, affine=True)
+
+        self.conv1_0 = nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1, bias=False)
+        self.bn1_0 = nn.BatchNorm2d(num_feat * 2, affine=True)
+        self.conv1_1 = nn.Conv2d(num_feat * 2, num_feat * 2, 4, 2, 1, bias=False)
+        self.bn1_1 = nn.BatchNorm2d(num_feat * 2, affine=True)
+
+        self.conv2_0 = nn.Conv2d(num_feat * 2, num_feat * 4, 3, 1, 1, bias=False)
+        self.bn2_0 = nn.BatchNorm2d(num_feat * 4, affine=True)
+        self.conv2_1 = nn.Conv2d(num_feat * 4, num_feat * 4, 4, 2, 1, bias=False)
+        self.bn2_1 = nn.BatchNorm2d(num_feat * 4, affine=True)
+
+        self.conv3_0 = nn.Conv2d(num_feat * 4, num_feat * 8, 3, 1, 1, bias=False)
+        self.bn3_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
+        self.conv3_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
+        self.bn3_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
+
+        self.conv4_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
+        self.bn4_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
+        self.conv4_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
+        self.bn4_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
+
+        if self.input_size == 256:
+            self.conv5_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
+            self.bn5_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
+            self.conv5_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
+            self.bn5_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
+
+        self.linear1 = nn.Linear(num_feat * 8 * 4 * 4, 100)
+        self.linear2 = nn.Linear(100, 1)
+
+        # activation function
+        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+    def forward(self, x):
+        assert x.size(2) == self.input_size, (f'Input size must be identical to input_size, but received {x.size()}.')
+
+        feat = self.lrelu(self.conv0_0(x))
+        feat = self.lrelu(self.bn0_1(self.conv0_1(feat)))  # output spatial size: /2
+
+        feat = self.lrelu(self.bn1_0(self.conv1_0(feat)))
+        feat = self.lrelu(self.bn1_1(self.conv1_1(feat)))  # output spatial size: /4
+
+        feat = self.lrelu(self.bn2_0(self.conv2_0(feat)))
+        feat = self.lrelu(self.bn2_1(self.conv2_1(feat)))  # output spatial size: /8
+
+        feat = self.lrelu(self.bn3_0(self.conv3_0(feat)))
+        feat = self.lrelu(self.bn3_1(self.conv3_1(feat)))  # output spatial size: /16
+
+        feat = self.lrelu(self.bn4_0(self.conv4_0(feat)))
+        feat = self.lrelu(self.bn4_1(self.conv4_1(feat)))  # output spatial size: /32
+
+        if self.input_size == 256:
+            feat = self.lrelu(self.bn5_0(self.conv5_0(feat)))
+            feat = self.lrelu(self.bn5_1(self.conv5_1(feat)))  # output spatial size: /64
+
+        # spatial size: (4, 4)
+        feat = feat.view(feat.size(0), -1)
+        feat = self.lrelu(self.linear1(feat))
+        out = self.linear2(feat)
+        return out
+
+
+@ARCH_REGISTRY.register(suffix='basicsr')
+class UNetDiscriminatorSN(nn.Module):
+    """Defines a U-Net discriminator with spectral normalization (SN).
+
+    It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
+
+    Args:
+        num_in_ch (int): Channel number of inputs. Default: 3.
+        num_feat (int): Channel number of base intermediate features. Default: 64.
+        skip_connection (bool): Whether to use skip connections in the U-Net. Default: True.
+    """
+
+    def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
+        super(UNetDiscriminatorSN, self).__init__()
+        self.skip_connection = skip_connection
+        norm = spectral_norm
+        # the first convolution
+        self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
+        # downsample
+        self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
+        self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
+        self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
+        # upsample
+        self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
+        self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
+        self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
+        # extra convolutions
+        self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
+        self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
+        self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
+
+    def forward(self, x):
+        # downsample
+        x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
+        x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
+        x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
+        x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
+
+        # upsample
+        x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
+        x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
+
+        if self.skip_connection:
+            x4 = x4 + x2
+        x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
+        x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
+
+        if self.skip_connection:
+            x5 = x5 + x1
+        x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
+        x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
+
+        if self.skip_connection:
+            x6 = x6 + x0
+
+        # extra convolutions
+        out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
+        out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
+        out = self.conv9(out)
+
+        return out
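The two discriminators above differ in output granularity: VGGStyleDiscriminator emits one logit per image, while UNetDiscriminatorSN emits a one-channel logit map at the input resolution. A hypothetical shape check, assuming the classes above are importable:

import torch

d_vgg = VGGStyleDiscriminator(num_in_ch=3, num_feat=64, input_size=128)
d_unet = UNetDiscriminatorSN(num_in_ch=3, num_feat=64)

x = torch.rand(2, 3, 128, 128)
print(d_vgg(x).shape)   # torch.Size([2, 1])           -- one logit per image
print(d_unet(x).shape)  # torch.Size([2, 1, 128, 128]) -- per-pixel logits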
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/duf_arch.py
ADDED
@@ -0,0 +1,277 @@
+import numpy as np
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from r_basicsr.utils.registry import ARCH_REGISTRY
+
+
+class DenseBlocksTemporalReduce(nn.Module):
+    """A concatenation of 3 dense blocks with reduction in temporal dimension.
+
+    Note that the output temporal dimension is 6 fewer than the input temporal dimension, since there are 3 blocks.
+
+    Args:
+        num_feat (int): Number of channels in the blocks. Default: 64.
+        num_grow_ch (int): Growing factor of the dense blocks. Default: 32.
+        adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
+            Set to false if you want to train from scratch. Default: False.
+    """
+
+    def __init__(self, num_feat=64, num_grow_ch=32, adapt_official_weights=False):
+        super(DenseBlocksTemporalReduce, self).__init__()
+        if adapt_official_weights:
+            eps = 1e-3
+            momentum = 1e-3
+        else:  # pytorch default values
+            eps = 1e-05
+            momentum = 0.1
+
+        self.temporal_reduce1 = nn.Sequential(
+            nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+            nn.Conv3d(num_feat, num_feat, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True),
+            nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+            nn.Conv3d(num_feat, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
+
+        self.temporal_reduce2 = nn.Sequential(
+            nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+            nn.Conv3d(
+                num_feat + num_grow_ch,
+                num_feat + num_grow_ch, (1, 1, 1),
+                stride=(1, 1, 1),
+                padding=(0, 0, 0),
+                bias=True), nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+            nn.Conv3d(num_feat + num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
+
+        self.temporal_reduce3 = nn.Sequential(
+            nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+            nn.Conv3d(
+                num_feat + 2 * num_grow_ch,
+                num_feat + 2 * num_grow_ch, (1, 1, 1),
+                stride=(1, 1, 1),
+                padding=(0, 0, 0),
+                bias=True), nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum),
+            nn.ReLU(inplace=True),
+            nn.Conv3d(
+                num_feat + 2 * num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
+
+    def forward(self, x):
+        """
+        Args:
+            x (Tensor): Input tensor with shape (b, num_feat, t, h, w).
+
+        Returns:
+            Tensor: Output with shape (b, num_feat + num_grow_ch * 3, 1, h, w).
+        """
+        x1 = self.temporal_reduce1(x)
+        x1 = torch.cat((x[:, :, 1:-1, :, :], x1), 1)
+
+        x2 = self.temporal_reduce2(x1)
+        x2 = torch.cat((x1[:, :, 1:-1, :, :], x2), 1)
+
+        x3 = self.temporal_reduce3(x2)
+        x3 = torch.cat((x2[:, :, 1:-1, :, :], x3), 1)
+
+        return x3
+
+
+class DenseBlocks(nn.Module):
+    """A concatenation of N dense blocks.
+
+    Args:
+        num_feat (int): Number of channels in the blocks. Default: 64.
+        num_grow_ch (int): Growing factor of the dense blocks. Default: 16.
+        num_block (int): Number of dense blocks. The values are:
+            DUF-S (16 layers): 3
+            DUF-M (28 layers): 9
+            DUF-L (52 layers): 21
+        adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
+            Set to false if you want to train from scratch. Default: False.
+    """
+
+    def __init__(self, num_block, num_feat=64, num_grow_ch=16, adapt_official_weights=False):
+        super(DenseBlocks, self).__init__()
+        if adapt_official_weights:
+            eps = 1e-3
+            momentum = 1e-3
+        else:  # pytorch default values
+            eps = 1e-05
+            momentum = 0.1
+
+        self.dense_blocks = nn.ModuleList()
+        for i in range(0, num_block):
+            self.dense_blocks.append(
+                nn.Sequential(
+                    nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+                    nn.Conv3d(
+                        num_feat + i * num_grow_ch,
+                        num_feat + i * num_grow_ch, (1, 1, 1),
+                        stride=(1, 1, 1),
+                        padding=(0, 0, 0),
+                        bias=True), nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum),
+                    nn.ReLU(inplace=True),
+                    nn.Conv3d(
+                        num_feat + i * num_grow_ch,
+                        num_grow_ch, (3, 3, 3),
+                        stride=(1, 1, 1),
+                        padding=(1, 1, 1),
+                        bias=True)))
+
+    def forward(self, x):
+        """
+        Args:
+            x (Tensor): Input tensor with shape (b, num_feat, t, h, w).
+
+        Returns:
+            Tensor: Output with shape (b, num_feat + num_block * num_grow_ch, t, h, w).
+        """
+        for i in range(0, len(self.dense_blocks)):
+            y = self.dense_blocks[i](x)
+            x = torch.cat((x, y), 1)
+        return x
+
+
+class DynamicUpsamplingFilter(nn.Module):
+    """Dynamic upsampling filter used in DUF.
+
+    Ref: https://github.com/yhjo09/VSR-DUF.
+    It only supports input with 3 channels, and it applies the same filters to all 3 channels.
+
+    Args:
+        filter_size (tuple): Filter size of generated filters. The shape is (kh, kw). Default: (5, 5).
+    """
+
+    def __init__(self, filter_size=(5, 5)):
+        super(DynamicUpsamplingFilter, self).__init__()
+        if not isinstance(filter_size, tuple):
+            raise TypeError(f'The type of filter_size must be tuple, but got {type(filter_size)}')
+        if len(filter_size) != 2:
+            raise ValueError(f'The length of filter size must be 2, but got {len(filter_size)}.')
+        # generate a local expansion filter, similar to im2col
+        self.filter_size = filter_size
+        filter_prod = np.prod(filter_size)
+        expansion_filter = torch.eye(int(filter_prod)).view(filter_prod, 1, *filter_size)  # (kh*kw, 1, kh, kw)
+        self.expansion_filter = expansion_filter.repeat(3, 1, 1, 1)  # repeat for all the 3 channels
+
+    def forward(self, x, filters):
+        """Forward function for DynamicUpsamplingFilter.
+
+        Args:
+            x (Tensor): Input image with 3 channels. The shape is (n, 3, h, w).
+            filters (Tensor): Generated dynamic filters.
+                The shape is (n, filter_prod, upsampling_square, h, w).
+                filter_prod: prod of filter kernel size, e.g., 1*5*5=25.
+                upsampling_square: similar to pixel shuffle,
+                    upsampling_square = upsampling * upsampling.
+                    e.g., for x4 upsampling, upsampling_square = 4*4 = 16.
+
+        Returns:
+            Tensor: Filtered image with shape (n, 3*upsampling_square, h, w).
+        """
+        n, filter_prod, upsampling_square, h, w = filters.size()
+        kh, kw = self.filter_size
+        expanded_input = F.conv2d(
+            x, self.expansion_filter.to(x), padding=(kh // 2, kw // 2), groups=3)  # (n, 3*filter_prod, h, w)
+        expanded_input = expanded_input.view(n, 3, filter_prod, h, w).permute(0, 3, 4, 1,
+                                                                              2)  # (n, h, w, 3, filter_prod)
+        filters = filters.permute(0, 3, 4, 1, 2)  # (n, h, w, filter_prod, upsampling_square)
+        out = torch.matmul(expanded_input, filters)  # (n, h, w, 3, upsampling_square)
+        return out.permute(0, 3, 4, 1, 2).view(n, 3 * upsampling_square, h, w)
+
+
+@ARCH_REGISTRY.register()
+class DUF(nn.Module):
+    """Network architecture for DUF.
+
+    Paper: Jo et al. Deep Video Super-Resolution Network Using Dynamic
+        Upsampling Filters Without Explicit Motion Compensation, CVPR, 2018.
+    Code reference: https://github.com/yhjo09/VSR-DUF
+    For all the models below, 'adapt_official_weights' is only necessary when
+    loading the weights converted from the official TensorFlow weights.
+    Please set it to False if you are training the model from scratch.
+
+    There are three models with different model sizes: DUF16Layers, DUF28Layers,
+    and DUF52Layers. This class is the base class for these models.
+
+    Args:
+        scale (int): The upsampling factor. Default: 4.
+        num_layer (int): The number of layers. Default: 52.
+        adapt_official_weights (bool): Whether to adapt the weights
+            translated from the official implementation. Set to false if you
+            want to train from scratch. Default: False.
+    """
+
+    def __init__(self, scale=4, num_layer=52, adapt_official_weights=False):
+        super(DUF, self).__init__()
+        self.scale = scale
+        if adapt_official_weights:
+            eps = 1e-3
+            momentum = 1e-3
+        else:  # pytorch default values
+            eps = 1e-05
+            momentum = 0.1
+
+        self.conv3d1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
+        self.dynamic_filter = DynamicUpsamplingFilter((5, 5))
+
+        if num_layer == 16:
+            num_block = 3
+            num_grow_ch = 32
+        elif num_layer == 28:
+            num_block = 9
+            num_grow_ch = 16
+        elif num_layer == 52:
+            num_block = 21
+            num_grow_ch = 16
+        else:
+            raise ValueError(f'Only supported (16, 28, 52) layers, but got {num_layer}.')
+
+        self.dense_block1 = DenseBlocks(
+            num_block=num_block, num_feat=64, num_grow_ch=num_grow_ch,
+            adapt_official_weights=adapt_official_weights)  # T = 7
+        self.dense_block2 = DenseBlocksTemporalReduce(
+            64 + num_grow_ch * num_block, num_grow_ch, adapt_official_weights=adapt_official_weights)  # T = 1
+        channels = 64 + num_grow_ch * num_block + num_grow_ch * 3
+        self.bn3d2 = nn.BatchNorm3d(channels, eps=eps, momentum=momentum)
+        self.conv3d2 = nn.Conv3d(channels, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
+
+        self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
+        self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
+
+        self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
+        self.conv3d_f2 = nn.Conv3d(
+            512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
+
+    def forward(self, x):
+        """
+        Args:
+            x (Tensor): Input with shape (b, 7, c, h, w).
+
+        Returns:
+            Tensor: Output with shape (b, c, h * scale, w * scale).
+        """
+        num_batches, num_imgs, _, h, w = x.size()
+
+        x = x.permute(0, 2, 1, 3, 4)  # (b, c, 7, h, w) for Conv3D
+        x_center = x[:, :, num_imgs // 2, :, :]
+
+        x = self.conv3d1(x)
+        x = self.dense_block1(x)
+        x = self.dense_block2(x)
+        x = F.relu(self.bn3d2(x), inplace=True)
+        x = F.relu(self.conv3d2(x), inplace=True)
+
+        # residual image
+        res = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True))
+
+        # filter
+        filter_ = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True))
+        filter_ = F.softmax(filter_.view(num_batches, 25, self.scale**2, h, w), dim=1)
+
+        # dynamic filter
+        out = self.dynamic_filter(x_center, filter_)
+        out += res.squeeze_(2)
+        out = F.pixel_shuffle(out, self.scale)
+
+        return out
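DUF.forward above consumes exactly seven frames and upscales only the centre one. A hypothetical smoke test of the 52-layer variant (shapes chosen for illustration only):

import torch

model = DUF(scale=4, num_layer=52)
model.eval()  # use BatchNorm running statistics

lr_clip = torch.rand(1, 7, 3, 32, 32)  # (b, t, c, h, w); t must be 7
with torch.no_grad():
    sr = model(lr_clip)
print(sr.shape)  # torch.Size([1, 3, 128, 128]) -- only the centre frame is upscaled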
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/ecbsr_arch.py
ADDED
@@ -0,0 +1,274 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from r_basicsr.utils.registry import ARCH_REGISTRY
|
6 |
+
|
7 |
+
|
8 |
+
class SeqConv3x3(nn.Module):
|
9 |
+
"""The re-parameterizable block used in the ECBSR architecture.
|
10 |
+
|
11 |
+
Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
|
12 |
+
Ref git repo: https://github.com/xindongzhang/ECBSR
|
13 |
+
|
14 |
+
Args:
|
15 |
+
seq_type (str): Sequence type, option: conv1x1-conv3x3 | conv1x1-sobelx | conv1x1-sobely | conv1x1-laplacian.
|
16 |
+
in_channels (int): Channel number of input.
|
17 |
+
out_channels (int): Channel number of output.
|
18 |
+
depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1.
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, seq_type, in_channels, out_channels, depth_multiplier=1):
|
22 |
+
super(SeqConv3x3, self).__init__()
|
23 |
+
self.seq_type = seq_type
|
24 |
+
self.in_channels = in_channels
|
25 |
+
self.out_channels = out_channels
|
26 |
+
|
27 |
+
if self.seq_type == 'conv1x1-conv3x3':
|
28 |
+
self.mid_planes = int(out_channels * depth_multiplier)
|
29 |
+
conv0 = torch.nn.Conv2d(self.in_channels, self.mid_planes, kernel_size=1, padding=0)
|
30 |
+
self.k0 = conv0.weight
|
31 |
+
self.b0 = conv0.bias
|
32 |
+
|
33 |
+
conv1 = torch.nn.Conv2d(self.mid_planes, self.out_channels, kernel_size=3)
|
34 |
+
self.k1 = conv1.weight
|
35 |
+
self.b1 = conv1.bias
|
36 |
+
|
37 |
+
elif self.seq_type == 'conv1x1-sobelx':
|
38 |
+
conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
|
39 |
+
self.k0 = conv0.weight
|
40 |
+
self.b0 = conv0.bias
|
41 |
+
|
42 |
+
# init scale and bias
|
43 |
+
scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
|
44 |
+
self.scale = nn.Parameter(scale)
|
45 |
+
bias = torch.randn(self.out_channels) * 1e-3
|
46 |
+
bias = torch.reshape(bias, (self.out_channels, ))
|
47 |
+
self.bias = nn.Parameter(bias)
|
48 |
+
# init mask
|
49 |
+
self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
|
50 |
+
for i in range(self.out_channels):
|
51 |
+
self.mask[i, 0, 0, 0] = 1.0
|
52 |
+
self.mask[i, 0, 1, 0] = 2.0
|
53 |
+
self.mask[i, 0, 2, 0] = 1.0
|
54 |
+
self.mask[i, 0, 0, 2] = -1.0
|
55 |
+
self.mask[i, 0, 1, 2] = -2.0
|
56 |
+
self.mask[i, 0, 2, 2] = -1.0
|
57 |
+
self.mask = nn.Parameter(data=self.mask, requires_grad=False)
|
58 |
+
|
59 |
+
elif self.seq_type == 'conv1x1-sobely':
|
60 |
+
conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
|
61 |
+
self.k0 = conv0.weight
|
62 |
+
self.b0 = conv0.bias
|
63 |
+
|
64 |
+
# init scale and bias
|
65 |
+
scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
|
66 |
+
self.scale = nn.Parameter(torch.FloatTensor(scale))
|
67 |
+
bias = torch.randn(self.out_channels) * 1e-3
|
68 |
+
bias = torch.reshape(bias, (self.out_channels, ))
|
69 |
+
self.bias = nn.Parameter(torch.FloatTensor(bias))
|
70 |
+
# init mask
|
71 |
+
self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
|
72 |
+
for i in range(self.out_channels):
|
73 |
+
self.mask[i, 0, 0, 0] = 1.0
|
74 |
+
self.mask[i, 0, 0, 1] = 2.0
|
75 |
+
self.mask[i, 0, 0, 2] = 1.0
|
76 |
+
self.mask[i, 0, 2, 0] = -1.0
|
77 |
+
self.mask[i, 0, 2, 1] = -2.0
|
78 |
+
self.mask[i, 0, 2, 2] = -1.0
|
79 |
+
self.mask = nn.Parameter(data=self.mask, requires_grad=False)
|
80 |
+
|
81 |
+
elif self.seq_type == 'conv1x1-laplacian':
|
82 |
+
conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
|
83 |
+
self.k0 = conv0.weight
|
84 |
+
self.b0 = conv0.bias
|
85 |
+
|
86 |
+
# init scale and bias
|
87 |
+
scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
|
88 |
+
self.scale = nn.Parameter(torch.FloatTensor(scale))
|
89 |
+
bias = torch.randn(self.out_channels) * 1e-3
|
90 |
+
bias = torch.reshape(bias, (self.out_channels, ))
|
91 |
+
self.bias = nn.Parameter(torch.FloatTensor(bias))
|
92 |
+
# init mask
|
93 |
+
self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
|
94 |
+
for i in range(self.out_channels):
|
95 |
+
self.mask[i, 0, 0, 1] = 1.0
|
96 |
+
self.mask[i, 0, 1, 0] = 1.0
|
97 |
+
self.mask[i, 0, 1, 2] = 1.0
|
98 |
+
self.mask[i, 0, 2, 1] = 1.0
|
99 |
+
self.mask[i, 0, 1, 1] = -4.0
|
100 |
+
self.mask = nn.Parameter(data=self.mask, requires_grad=False)
|
101 |
+
else:
|
102 |
+
raise ValueError('The type of seqconv is not supported!')
|
103 |
+
|
104 |
+
def forward(self, x):
|
105 |
+
if self.seq_type == 'conv1x1-conv3x3':
|
106 |
+
# conv-1x1
|
107 |
+
y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
|
108 |
+
# explicitly padding with bias
|
109 |
+
y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
|
110 |
+
b0_pad = self.b0.view(1, -1, 1, 1)
|
111 |
+
y0[:, :, 0:1, :] = b0_pad
|
112 |
+
y0[:, :, -1:, :] = b0_pad
|
113 |
+
y0[:, :, :, 0:1] = b0_pad
|
114 |
+
y0[:, :, :, -1:] = b0_pad
|
115 |
+
# conv-3x3
|
116 |
+
y1 = F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
|
117 |
+
else:
|
118 |
+
            y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
            # explicitly padding with bias
            y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
            b0_pad = self.b0.view(1, -1, 1, 1)
            y0[:, :, 0:1, :] = b0_pad
            y0[:, :, -1:, :] = b0_pad
            y0[:, :, :, 0:1] = b0_pad
            y0[:, :, :, -1:] = b0_pad
            # conv-3x3
            y1 = F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias, stride=1, groups=self.out_channels)
        return y1

    def rep_params(self):
        device = self.k0.get_device()
        if device < 0:
            device = None

        if self.seq_type == 'conv1x1-conv3x3':
            # re-param conv kernel
            rep_weight = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3))
            # re-param conv bias
            rep_bias = torch.ones(1, self.mid_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
            rep_bias = F.conv2d(input=rep_bias, weight=self.k1).view(-1, ) + self.b1
        else:
            tmp = self.scale * self.mask
            k1 = torch.zeros((self.out_channels, self.out_channels, 3, 3), device=device)
            for i in range(self.out_channels):
                k1[i, i, :, :] = tmp[i, 0, :, :]
            b1 = self.bias
            # re-param conv kernel
            rep_weight = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3))
            # re-param conv bias
            rep_bias = torch.ones(1, self.out_channels, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
            rep_bias = F.conv2d(input=rep_bias, weight=k1).view(-1, ) + b1
        return rep_weight, rep_bias


class ECB(nn.Module):
    """The ECB block used in the ECBSR architecture.

    Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
    Ref git repo: https://github.com/xindongzhang/ECBSR

    Args:
        in_channels (int): Channel number of input.
        out_channels (int): Channel number of output.
        depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1.
        act_type (str): Activation type. Option: prelu | relu | rrelu | softplus | linear. Default: prelu.
        with_idt (bool): Whether to use identity connection. Default: False.
    """

    def __init__(self, in_channels, out_channels, depth_multiplier, act_type='prelu', with_idt=False):
        super(ECB, self).__init__()

        self.depth_multiplier = depth_multiplier
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.act_type = act_type

        if with_idt and (self.in_channels == self.out_channels):
            self.with_idt = True
        else:
            self.with_idt = False

        self.conv3x3 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1)
        self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.in_channels, self.out_channels, self.depth_multiplier)
        self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.in_channels, self.out_channels)
        self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.in_channels, self.out_channels)
        self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.in_channels, self.out_channels)

        if self.act_type == 'prelu':
            self.act = nn.PReLU(num_parameters=self.out_channels)
        elif self.act_type == 'relu':
            self.act = nn.ReLU(inplace=True)
        elif self.act_type == 'rrelu':
            self.act = nn.RReLU(lower=-0.05, upper=0.05)
        elif self.act_type == 'softplus':
            self.act = nn.Softplus()
        elif self.act_type == 'linear':
            pass
        else:
            raise ValueError('The type of activation is not supported!')

    def forward(self, x):
        if self.training:
            y = self.conv3x3(x) + self.conv1x1_3x3(x) + self.conv1x1_sbx(x) + self.conv1x1_sby(x) + self.conv1x1_lpl(x)
            if self.with_idt:
                y += x
        else:
            rep_weight, rep_bias = self.rep_params()
            y = F.conv2d(input=x, weight=rep_weight, bias=rep_bias, stride=1, padding=1)
        if self.act_type != 'linear':
            y = self.act(y)
        return y

    def rep_params(self):
        weight0, bias0 = self.conv3x3.weight, self.conv3x3.bias
        weight1, bias1 = self.conv1x1_3x3.rep_params()
        weight2, bias2 = self.conv1x1_sbx.rep_params()
        weight3, bias3 = self.conv1x1_sby.rep_params()
        weight4, bias4 = self.conv1x1_lpl.rep_params()
        rep_weight, rep_bias = (weight0 + weight1 + weight2 + weight3 + weight4), (
            bias0 + bias1 + bias2 + bias3 + bias4)

        if self.with_idt:
            device = rep_weight.get_device()
            if device < 0:
                device = None
            weight_idt = torch.zeros(self.out_channels, self.out_channels, 3, 3, device=device)
            for i in range(self.out_channels):
                weight_idt[i, i, 1, 1] = 1.0
            bias_idt = 0.0
            rep_weight, rep_bias = rep_weight + weight_idt, rep_bias + bias_idt
        return rep_weight, rep_bias

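# --- Illustrative sketch (editor's addition, not part of the ReActor diff) ---
# Quick check that the kernel folded by ECB.rep_params() reproduces the
# multi-branch training path; hypothetical test code assuming the ECB class above.
import torch
import torch.nn.functional as F

ecb = ECB(8, 8, depth_multiplier=2.0, act_type='linear')  # training mode by default
x = torch.randn(1, 8, 16, 16)
with torch.no_grad():
    y_branches = ecb(x)          # conv3x3 + the four SeqConv3x3 branches
    w, b = ecb.rep_params()      # one folded 3x3 kernel and bias
    y_folded = F.conv2d(x, w, b, stride=1, padding=1)
assert torch.allclose(y_branches, y_folded, atol=1e-5)
# --- end sketch ---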
@ARCH_REGISTRY.register()
class ECBSR(nn.Module):
    """ECBSR architecture.

    Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
    Ref git repo: https://github.com/xindongzhang/ECBSR

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        num_block (int): Block number in the trunk network.
        num_channel (int): Channel number.
        with_idt (bool): Whether to use identity in convolution layers.
        act_type (str): Activation type.
        scale (int): Upsampling factor.
    """

    def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with_idt, act_type, scale):
        super(ECBSR, self).__init__()
        self.num_in_ch = num_in_ch
        self.scale = scale

        backbone = []
        backbone += [ECB(num_in_ch, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
        for _ in range(num_block):
            backbone += [ECB(num_channel, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
        backbone += [
            ECB(num_channel, num_out_ch * scale * scale, depth_multiplier=2.0, act_type='linear', with_idt=with_idt)
        ]

        self.backbone = nn.Sequential(*backbone)
        self.upsampler = nn.PixelShuffle(scale)

    def forward(self, x):
        if self.num_in_ch > 1:
            shortcut = torch.repeat_interleave(x, self.scale * self.scale, dim=1)
        else:
            shortcut = x  # broadcasting repeats the single input channel scale * scale times
        y = self.backbone(x) + shortcut
        y = self.upsampler(y)
        return y
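# --- Illustrative sketch (editor's addition, not part of the ReActor diff) ---
# Running ECBSR end to end on a dummy one-channel LR image; the
# hyper-parameters below are illustrative, not a recommended setting.
import torch

net = ECBSR(num_in_ch=1, num_out_ch=1, num_block=4, num_channel=16,
            with_idt=False, act_type='prelu', scale=4)
net.eval()  # in eval mode every ECB runs its re-parameterized single conv
lr = torch.rand(1, 1, 32, 32)
with torch.no_grad():
    sr = net(lr)
print(sr.shape)  # torch.Size([1, 1, 128, 128])
# --- end sketch ---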
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/edsr_arch.py
ADDED
@@ -0,0 +1,61 @@
import torch
from torch import nn as nn

from r_basicsr.archs.arch_util import ResidualBlockNoBN, Upsample, make_layer
from r_basicsr.utils.registry import ARCH_REGISTRY


@ARCH_REGISTRY.register()
class EDSR(nn.Module):
    """EDSR network structure.

    Paper: Enhanced Deep Residual Networks for Single Image Super-Resolution.
    Ref git repo: https://github.com/thstkdgus35/EDSR-PyTorch

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        num_block (int): Block number in the trunk network. Default: 16.
        upscale (int): Upsampling factor. Support 2^n and 3.
            Default: 4.
        res_scale (float): Used to scale the residual in residual block.
            Default: 1.
        img_range (float): Image range. Default: 255.
        rgb_mean (tuple[float]): Image mean in RGB orders.
            Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
    """

    def __init__(self,
                 num_in_ch,
                 num_out_ch,
                 num_feat=64,
                 num_block=16,
                 upscale=4,
                 res_scale=1,
                 img_range=255.,
                 rgb_mean=(0.4488, 0.4371, 0.4040)):
        super(EDSR, self).__init__()

        self.img_range = img_range
        self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)

        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat, res_scale=res_scale, pytorch_init=True)
        self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.upsample = Upsample(upscale, num_feat)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

    def forward(self, x):
        self.mean = self.mean.type_as(x)

        x = (x - self.mean) * self.img_range
        x = self.conv_first(x)
        res = self.conv_after_body(self.body(x))
        res += x

        x = self.conv_last(self.upsample(res))
        x = x / self.img_range + self.mean

        return x
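# --- Illustrative sketch (editor's addition, not part of the ReActor diff) ---
# An x4 EDSR forward pass with the documented defaults; inputs are expected
# in [0, 1], and the mean shift / img_range scaling happens inside forward().
import torch

model = EDSR(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4)
x = torch.rand(1, 3, 48, 48)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 3, 192, 192])
# --- end sketch ---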
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/edvr_arch.py
ADDED
@@ -0,0 +1,383 @@
import torch
from torch import nn as nn
from torch.nn import functional as F

from r_basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import DCNv2Pack, ResidualBlockNoBN, make_layer


class PCDAlignment(nn.Module):
    """Alignment module using Pyramid, Cascading and Deformable convolution
    (PCD). It is used in EDVR.

    Ref:
        EDVR: Video Restoration with Enhanced Deformable Convolutional Networks

    Args:
        num_feat (int): Channel number of middle features. Default: 64.
        deformable_groups (int): Deformable groups. Default: 8.
    """

    def __init__(self, num_feat=64, deformable_groups=8):
        super(PCDAlignment, self).__init__()

        # Pyramid has three levels:
        # L3: level 3, 1/4 spatial size
        # L2: level 2, 1/2 spatial size
        # L1: level 1, original spatial size
        self.offset_conv1 = nn.ModuleDict()
        self.offset_conv2 = nn.ModuleDict()
        self.offset_conv3 = nn.ModuleDict()
        self.dcn_pack = nn.ModuleDict()
        self.feat_conv = nn.ModuleDict()

        # Pyramids
        for i in range(3, 0, -1):
            level = f'l{i}'
            self.offset_conv1[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
            if i == 3:
                self.offset_conv2[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            else:
                self.offset_conv2[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
                self.offset_conv3[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.dcn_pack[level] = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)

            if i < 3:
                self.feat_conv[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)

        # Cascading dcn
        self.cas_offset_conv1 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
        self.cas_offset_conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.cas_dcnpack = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, nbr_feat_l, ref_feat_l):
        """Align neighboring frame features to the reference frame features.

        Args:
            nbr_feat_l (list[Tensor]): Neighboring feature list. It
                contains three pyramid levels (L1, L2, L3),
                each with shape (b, c, h, w).
            ref_feat_l (list[Tensor]): Reference feature list. It
                contains three pyramid levels (L1, L2, L3),
                each with shape (b, c, h, w).

        Returns:
            Tensor: Aligned features.
        """
        # Pyramids
        upsampled_offset, upsampled_feat = None, None
        for i in range(3, 0, -1):
            level = f'l{i}'
            offset = torch.cat([nbr_feat_l[i - 1], ref_feat_l[i - 1]], dim=1)
            offset = self.lrelu(self.offset_conv1[level](offset))
            if i == 3:
                offset = self.lrelu(self.offset_conv2[level](offset))
            else:
                offset = self.lrelu(self.offset_conv2[level](torch.cat([offset, upsampled_offset], dim=1)))
                offset = self.lrelu(self.offset_conv3[level](offset))

            feat = self.dcn_pack[level](nbr_feat_l[i - 1], offset)
            if i < 3:
                feat = self.feat_conv[level](torch.cat([feat, upsampled_feat], dim=1))
            if i > 1:
                feat = self.lrelu(feat)

            if i > 1:  # upsample offset and features
                # x2: when we upsample the offset, we should also enlarge
                # the magnitude.
                upsampled_offset = self.upsample(offset) * 2
                upsampled_feat = self.upsample(feat)

        # Cascading
        offset = torch.cat([feat, ref_feat_l[0]], dim=1)
        offset = self.lrelu(self.cas_offset_conv2(self.lrelu(self.cas_offset_conv1(offset))))
        feat = self.lrelu(self.cas_dcnpack(feat, offset))
        return feat

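# --- Illustrative sketch (editor's addition, not part of the ReActor diff) ---
# Feeding PCDAlignment hand-built three-level pyramids (each level halves h
# and w, matching the strided convs in EDVR below). Note that DCNv2Pack only
# works if the deformable-conv ops it wraps are available in this build.
import torch

pcd = PCDAlignment(num_feat=64, deformable_groups=8)
nbr = [torch.rand(1, 64, 64, 64), torch.rand(1, 64, 32, 32), torch.rand(1, 64, 16, 16)]
ref = [torch.rand(1, 64, 64, 64), torch.rand(1, 64, 32, 32), torch.rand(1, 64, 16, 16)]
with torch.no_grad():
    aligned = pcd(nbr, ref)  # (1, 64, 64, 64): neighbor features warped toward the reference
# --- end sketch ---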
class TSAFusion(nn.Module):
    """Temporal Spatial Attention (TSA) fusion module.

    Temporal: Calculate the correlation between center frame and
        neighboring frames;
    Spatial: It has 3 pyramid levels, the attention is similar to SFT.
        (SFT: Recovering realistic texture in image super-resolution by deep
        spatial feature transform.)

    Args:
        num_feat (int): Channel number of middle features. Default: 64.
        num_frame (int): Number of frames. Default: 5.
        center_frame_idx (int): The index of center frame. Default: 2.
    """

    def __init__(self, num_feat=64, num_frame=5, center_frame_idx=2):
        super(TSAFusion, self).__init__()
        self.center_frame_idx = center_frame_idx
        # temporal attention (before fusion conv)
        self.temporal_attn1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.temporal_attn2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.feat_fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)

        # spatial attention (after fusion conv)
        self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)
        self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1)
        self.spatial_attn1 = nn.Conv2d(num_frame * num_feat, num_feat, 1)
        self.spatial_attn2 = nn.Conv2d(num_feat * 2, num_feat, 1)
        self.spatial_attn3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.spatial_attn4 = nn.Conv2d(num_feat, num_feat, 1)
        self.spatial_attn5 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.spatial_attn_l1 = nn.Conv2d(num_feat, num_feat, 1)
        self.spatial_attn_l2 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
        self.spatial_attn_l3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.spatial_attn_add1 = nn.Conv2d(num_feat, num_feat, 1)
        self.spatial_attn_add2 = nn.Conv2d(num_feat, num_feat, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, aligned_feat):
        """
        Args:
            aligned_feat (Tensor): Aligned features with shape (b, t, c, h, w).

        Returns:
            Tensor: Features after TSA with the shape (b, c, h, w).
        """
        b, t, c, h, w = aligned_feat.size()
        # temporal attention
        embedding_ref = self.temporal_attn1(aligned_feat[:, self.center_frame_idx, :, :, :].clone())
        embedding = self.temporal_attn2(aligned_feat.view(-1, c, h, w))
        embedding = embedding.view(b, t, -1, h, w)  # (b, t, c, h, w)

        corr_l = []  # correlation list
        for i in range(t):
            emb_neighbor = embedding[:, i, :, :, :]
            corr = torch.sum(emb_neighbor * embedding_ref, 1)  # (b, h, w)
            corr_l.append(corr.unsqueeze(1))  # (b, 1, h, w)
        corr_prob = torch.sigmoid(torch.cat(corr_l, dim=1))  # (b, t, h, w)
        corr_prob = corr_prob.unsqueeze(2).expand(b, t, c, h, w)
        corr_prob = corr_prob.contiguous().view(b, -1, h, w)  # (b, t*c, h, w)
        aligned_feat = aligned_feat.view(b, -1, h, w) * corr_prob

        # fusion
        feat = self.lrelu(self.feat_fusion(aligned_feat))

        # spatial attention
        attn = self.lrelu(self.spatial_attn1(aligned_feat))
        attn_max = self.max_pool(attn)
        attn_avg = self.avg_pool(attn)
        attn = self.lrelu(self.spatial_attn2(torch.cat([attn_max, attn_avg], dim=1)))
        # pyramid levels
        attn_level = self.lrelu(self.spatial_attn_l1(attn))
        attn_max = self.max_pool(attn_level)
        attn_avg = self.avg_pool(attn_level)
        attn_level = self.lrelu(self.spatial_attn_l2(torch.cat([attn_max, attn_avg], dim=1)))
        attn_level = self.lrelu(self.spatial_attn_l3(attn_level))
        attn_level = self.upsample(attn_level)

        attn = self.lrelu(self.spatial_attn3(attn)) + attn_level
        attn = self.lrelu(self.spatial_attn4(attn))
        attn = self.upsample(attn)
        attn = self.spatial_attn5(attn)
        attn_add = self.spatial_attn_add2(self.lrelu(self.spatial_attn_add1(attn)))
        attn = torch.sigmoid(attn)

        # after initialization, * 2 makes (attn * 2) to be close to 1.
        feat = feat * attn * 2 + attn_add
        return feat

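# --- Illustrative sketch (editor's addition, not part of the ReActor diff) ---
# Fusing five aligned frames with TSAFusion; shapes follow the docstring above.
import torch

tsa = TSAFusion(num_feat=64, num_frame=5, center_frame_idx=2)
aligned = torch.rand(1, 5, 64, 64, 64)  # (b, t, c, h, w)
with torch.no_grad():
    fused = tsa(aligned)  # (1, 64, 64, 64)
# --- end sketch ---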
class PredeblurModule(nn.Module):
    """Pre-deblur module.

    Args:
        num_in_ch (int): Channel number of input image. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        hr_in (bool): Whether the input has high resolution. Default: False.
    """

    def __init__(self, num_in_ch=3, num_feat=64, hr_in=False):
        super(PredeblurModule, self).__init__()
        self.hr_in = hr_in

        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        if self.hr_in:
            # downsample x4 by stride conv
            self.stride_conv_hr1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
            self.stride_conv_hr2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)

        # generate feature pyramid
        self.stride_conv_l2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.stride_conv_l3 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)

        self.resblock_l3 = ResidualBlockNoBN(num_feat=num_feat)
        self.resblock_l2_1 = ResidualBlockNoBN(num_feat=num_feat)
        self.resblock_l2_2 = ResidualBlockNoBN(num_feat=num_feat)
        self.resblock_l1 = nn.ModuleList([ResidualBlockNoBN(num_feat=num_feat) for i in range(5)])

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        feat_l1 = self.lrelu(self.conv_first(x))
        if self.hr_in:
            feat_l1 = self.lrelu(self.stride_conv_hr1(feat_l1))
            feat_l1 = self.lrelu(self.stride_conv_hr2(feat_l1))

        # generate feature pyramid
        feat_l2 = self.lrelu(self.stride_conv_l2(feat_l1))
        feat_l3 = self.lrelu(self.stride_conv_l3(feat_l2))

        feat_l3 = self.upsample(self.resblock_l3(feat_l3))
        feat_l2 = self.resblock_l2_1(feat_l2) + feat_l3
        feat_l2 = self.upsample(self.resblock_l2_2(feat_l2))

        for i in range(2):
            feat_l1 = self.resblock_l1[i](feat_l1)
        feat_l1 = feat_l1 + feat_l2
        for i in range(2, 5):
            feat_l1 = self.resblock_l1[i](feat_l1)
        return feat_l1

@ARCH_REGISTRY.register()
class EDVR(nn.Module):
    """EDVR network structure for video super-resolution.

    Currently only the x4 upsampling factor is supported.
    Paper:
        EDVR: Video Restoration with Enhanced Deformable Convolutional Networks

    Args:
        num_in_ch (int): Channel number of input image. Default: 3.
        num_out_ch (int): Channel number of output image. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_frame (int): Number of input frames. Default: 5.
        deformable_groups (int): Deformable groups. Default: 8.
        num_extract_block (int): Number of blocks for feature extraction.
            Default: 5.
        num_reconstruct_block (int): Number of blocks for reconstruction.
            Default: 10.
        center_frame_idx (int): The index of the center frame. Frames count
            from 0. Default: middle of the input frames.
        hr_in (bool): Whether the input has high resolution. Default: False.
        with_predeblur (bool): Whether to use the predeblur module.
            Default: False.
        with_tsa (bool): Whether to use the TSA module. Default: True.
    """

    def __init__(self,
                 num_in_ch=3,
                 num_out_ch=3,
                 num_feat=64,
                 num_frame=5,
                 deformable_groups=8,
                 num_extract_block=5,
                 num_reconstruct_block=10,
                 center_frame_idx=None,
                 hr_in=False,
                 with_predeblur=False,
                 with_tsa=True):
        super(EDVR, self).__init__()
        if center_frame_idx is None:
            self.center_frame_idx = num_frame // 2
        else:
            self.center_frame_idx = center_frame_idx
        self.hr_in = hr_in
        self.with_predeblur = with_predeblur
        self.with_tsa = with_tsa

        # extract features for each frame
        if self.with_predeblur:
            self.predeblur = PredeblurModule(num_feat=num_feat, hr_in=self.hr_in)
            self.conv_1x1 = nn.Conv2d(num_feat, num_feat, 1, 1)
        else:
            self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)

        # extract pyramid features
        self.feature_extraction = make_layer(ResidualBlockNoBN, num_extract_block, num_feat=num_feat)
        self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
        self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

        # pcd and tsa module
        self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=deformable_groups)
        if self.with_tsa:
            self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_frame, center_frame_idx=self.center_frame_idx)
        else:
            self.fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)

        # reconstruction
        self.reconstruction = make_layer(ResidualBlockNoBN, num_reconstruct_block, num_feat=num_feat)
        # upsample
        self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
        self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1)
        self.pixel_shuffle = nn.PixelShuffle(2)
        self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        b, t, c, h, w = x.size()
        if self.hr_in:
            assert h % 16 == 0 and w % 16 == 0, ('The height and width must be multiples of 16.')
        else:
            assert h % 4 == 0 and w % 4 == 0, ('The height and width must be multiples of 4.')

        x_center = x[:, self.center_frame_idx, :, :, :].contiguous()

        # extract features for each frame
        # L1
        if self.with_predeblur:
            feat_l1 = self.conv_1x1(self.predeblur(x.view(-1, c, h, w)))
            if self.hr_in:
                h, w = h // 4, w // 4
        else:
            feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))

        feat_l1 = self.feature_extraction(feat_l1)
        # L2
        feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
        feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
        # L3
        feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
        feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))

        feat_l1 = feat_l1.view(b, t, -1, h, w)
        feat_l2 = feat_l2.view(b, t, -1, h // 2, w // 2)
        feat_l3 = feat_l3.view(b, t, -1, h // 4, w // 4)

        # PCD alignment
        ref_feat_l = [  # reference feature list
            feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
            feat_l3[:, self.center_frame_idx, :, :, :].clone()
        ]
        aligned_feat = []
        for i in range(t):
            nbr_feat_l = [  # neighboring feature list
                feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
            ]
            aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
        aligned_feat = torch.stack(aligned_feat, dim=1)  # (b, t, c, h, w)

        if not self.with_tsa:
            aligned_feat = aligned_feat.view(b, -1, h, w)
        feat = self.fusion(aligned_feat)

        out = self.reconstruction(feat)
        out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
        out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
        out = self.lrelu(self.conv_hr(out))
        out = self.conv_last(out)
        if self.hr_in:
            base = x_center
        else:
            base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False)
        out += base
        return out
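# --- Illustrative sketch (editor's addition, not part of the ReActor diff) ---
# EDVR x4 video SR on a dummy 5-frame clip. h and w must be multiples of 4
# (see the assertion in forward()), and the DCN ops behind DCNv2Pack must be
# available for pcd_align to run.
import torch

edvr = EDVR(num_in_ch=3, num_out_ch=3, num_feat=64, num_frame=5)
clip = torch.rand(1, 5, 3, 64, 64)  # (b, t, c, h, w)
with torch.no_grad():
    out = edvr(clip)
print(out.shape)  # torch.Size([1, 3, 256, 256])
# --- end sketch ---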
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/hifacegan_arch.py
ADDED
@@ -0,0 +1,259 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from r_basicsr.utils.registry import ARCH_REGISTRY
from .hifacegan_util import BaseNetwork, LIPEncoder, SPADEResnetBlock, get_nonspade_norm_layer


class SPADEGenerator(BaseNetwork):
    """Generator with SPADEResBlock"""

    def __init__(self,
                 num_in_ch=3,
                 num_feat=64,
                 use_vae=False,
                 z_dim=256,
                 crop_size=512,
                 norm_g='spectralspadesyncbatch3x3',
                 is_train=True,
                 init_train_phase=3):  # progressive training disabled
        super().__init__()
        self.nf = num_feat
        self.input_nc = num_in_ch
        self.is_train = is_train
        self.train_phase = init_train_phase

        self.scale_ratio = 5  # hardcoded now
        self.sw = crop_size // (2**self.scale_ratio)
        self.sh = self.sw  # 20210519: By default use square image, aspect_ratio = 1.0

        if use_vae:
            # In case of VAE, we will sample from random z vector
            self.fc = nn.Linear(z_dim, 16 * self.nf * self.sw * self.sh)
        else:
            # Otherwise, we make the network deterministic by starting with
            # downsampled segmentation map instead of random z
            self.fc = nn.Conv2d(num_in_ch, 16 * self.nf, 3, padding=1)

        self.head_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)

        self.g_middle_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
        self.g_middle_1 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)

        self.ups = nn.ModuleList([
            SPADEResnetBlock(16 * self.nf, 8 * self.nf, norm_g),
            SPADEResnetBlock(8 * self.nf, 4 * self.nf, norm_g),
            SPADEResnetBlock(4 * self.nf, 2 * self.nf, norm_g),
            SPADEResnetBlock(2 * self.nf, 1 * self.nf, norm_g)
        ])

        self.to_rgbs = nn.ModuleList([
            nn.Conv2d(8 * self.nf, 3, 3, padding=1),
            nn.Conv2d(4 * self.nf, 3, 3, padding=1),
            nn.Conv2d(2 * self.nf, 3, 3, padding=1),
            nn.Conv2d(1 * self.nf, 3, 3, padding=1)
        ])

        self.up = nn.Upsample(scale_factor=2)

    def encode(self, input_tensor):
        """
        Encode input_tensor into feature maps; can be overridden in derived classes.
        Default: nearest downsampling of 2**5 = 32 times
        """
        h, w = input_tensor.size()[-2:]
        sh, sw = h // 2**self.scale_ratio, w // 2**self.scale_ratio
        x = F.interpolate(input_tensor, size=(sh, sw))
        return self.fc(x)

    def forward(self, x):
        # In the original SPADE, seg is a segmentation map, but here we use x instead.
        seg = x

        x = self.encode(x)
        x = self.head_0(x, seg)

        x = self.up(x)
        x = self.g_middle_0(x, seg)
        x = self.g_middle_1(x, seg)

        if self.is_train:
            phase = self.train_phase + 1
        else:
            phase = len(self.to_rgbs)

        for i in range(phase):
            x = self.up(x)
            x = self.ups[i](x, seg)

        x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
        x = torch.tanh(x)

        return x

    def mixed_guidance_forward(self, input_x, seg=None, n=0, mode='progressive'):
        """
        A helper for subspace visualization. Input and seg are different images.
        For the first n levels (including the encoder) we use input, for the rest we use seg.

        If mode = 'progressive', the output is like: AAABBB
        If mode = 'one_plug', the output is like: AAABAA
        If mode = 'one_ablate', the output is like: BBBABB
        """

        if seg is None:
            return self.forward(input_x)

        if self.is_train:
            phase = self.train_phase + 1
        else:
            phase = len(self.to_rgbs)

        if mode == 'progressive':
            n = max(min(n, 4 + phase), 0)
            guide_list = [input_x] * n + [seg] * (4 + phase - n)
        elif mode == 'one_plug':
            n = max(min(n, 4 + phase - 1), 0)
            guide_list = [seg] * (4 + phase)
            guide_list[n] = input_x
        elif mode == 'one_ablate':
            if n > 3 + phase:
                return self.forward(input_x)
            guide_list = [input_x] * (4 + phase)
            guide_list[n] = seg

        x = self.encode(guide_list[0])
        x = self.head_0(x, guide_list[1])

        x = self.up(x)
        x = self.g_middle_0(x, guide_list[2])
        x = self.g_middle_1(x, guide_list[3])

        for i in range(phase):
            x = self.up(x)
            x = self.ups[i](x, guide_list[4 + i])

        x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
        x = torch.tanh(x)

        return x

@ARCH_REGISTRY.register()
class HiFaceGAN(SPADEGenerator):
    """
    HiFaceGAN: SPADEGenerator with a learnable feature encoder
    Current encoder design: LIPEncoder
    """

    def __init__(self,
                 num_in_ch=3,
                 num_feat=64,
                 use_vae=False,
                 z_dim=256,
                 crop_size=512,
                 norm_g='spectralspadesyncbatch3x3',
                 is_train=True,
                 init_train_phase=3):
        super().__init__(num_in_ch, num_feat, use_vae, z_dim, crop_size, norm_g, is_train, init_train_phase)
        self.lip_encoder = LIPEncoder(num_in_ch, num_feat, self.sw, self.sh, self.scale_ratio)

    def encode(self, input_tensor):
        return self.lip_encoder(input_tensor)


@ARCH_REGISTRY.register()
class HiFaceGANDiscriminator(BaseNetwork):
    """
    Inspired by pix2pixHD multiscale discriminator.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_out_ch (int): Channel number of outputs. Default: 3.
        conditional_d (bool): Whether to use a conditional discriminator.
            Default: True.
        num_d (int): Number of multiscale discriminators. Default: 3.
        n_layers_d (int): Number of downsample layers in each D. Default: 4.
        num_feat (int): Channel number of base intermediate features.
            Default: 64.
        norm_d (str): String to determine normalization layers in D.
            Choices: [spectral][instance/batch/syncbatch]
            Default: 'spectralinstance'.
        keep_features (bool): Keep intermediate features for matching loss, etc.
            Default: True.
    """

    def __init__(self,
                 num_in_ch=3,
                 num_out_ch=3,
                 conditional_d=True,
                 num_d=2,
                 n_layers_d=4,
                 num_feat=64,
                 norm_d='spectralinstance',
                 keep_features=True):
        super().__init__()
        self.num_d = num_d

        input_nc = num_in_ch
        if conditional_d:
            input_nc += num_out_ch

        for i in range(num_d):
            subnet_d = NLayerDiscriminator(input_nc, n_layers_d, num_feat, norm_d, keep_features)
            self.add_module(f'discriminator_{i}', subnet_d)

    def downsample(self, x):
        return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)

    # Returns list of lists of discriminator outputs.
    # The final result is of size opt.num_d x opt.n_layers_D
    def forward(self, x):
        result = []
        for _, _net_d in self.named_children():
            out = _net_d(x)
            result.append(out)
            x = self.downsample(x)

        return result


class NLayerDiscriminator(BaseNetwork):
    """Defines the PatchGAN discriminator with the specified arguments."""

    def __init__(self, input_nc, n_layers_d, num_feat, norm_d, keep_features):
        super().__init__()
        kw = 4
        padw = int(np.ceil((kw - 1.0) / 2))
        nf = num_feat
        self.keep_features = keep_features

        norm_layer = get_nonspade_norm_layer(norm_d)
        sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, False)]]

        for n in range(1, n_layers_d):
            nf_prev = nf
            nf = min(nf * 2, 512)
            stride = 1 if n == n_layers_d - 1 else 2
            sequence += [[
                norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=stride, padding=padw)),
                nn.LeakyReLU(0.2, False)
            ]]

        sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]

        # We divide the layers into groups to extract intermediate layer outputs
        for n in range(len(sequence)):
            self.add_module('model' + str(n), nn.Sequential(*sequence[n]))

    def forward(self, x):
        results = [x]
        for submodel in self.children():
            intermediate_output = submodel(results[-1])
            results.append(intermediate_output)

        if self.keep_features:
            return results[1:]
        else:
            return results[-1]
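# --- Illustrative sketch (editor's addition, not part of the ReActor diff) ---
# HiFaceGAN restoration on a 512x512 crop. With is_train=False the generator
# walks all four up-blocks, so the output returns at the input resolution,
# tanh-scaled to (-1, 1); crop_size must match the 2**5 encoder downscaling.
import torch

g = HiFaceGAN(num_in_ch=3, num_feat=64, crop_size=512, is_train=False)
face = torch.rand(1, 3, 512, 512)
with torch.no_grad():
    restored = g(face)  # (1, 3, 512, 512)
# --- end sketch ---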
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/hifacegan_util.py
ADDED
@@ -0,0 +1,255 @@
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
# Warning: spectral norm could be buggy
# under eval mode and multi-GPU inference
# A workaround is sticking to single-GPU inference and train mode
from torch.nn.utils import spectral_norm


class SPADE(nn.Module):

    def __init__(self, config_text, norm_nc, label_nc):
        super().__init__()

        assert config_text.startswith('spade')
        parsed = re.search('spade(\\D+)(\\d)x\\d', config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))

        if param_free_norm_type == 'instance':
            self.param_free_norm = nn.InstanceNorm2d(norm_nc)
        elif param_free_norm_type == 'syncbatch':
            print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
            self.param_free_norm = nn.InstanceNorm2d(norm_nc)
        elif param_free_norm_type == 'batch':
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError(f'{param_free_norm_type} is not a recognized param-free norm type in SPADE')

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128 if norm_nc > 128 else norm_nc

        pw = ks // 2
        self.mlp_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)

    def forward(self, x, segmap):

        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on semantic map
        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        out = normalized * gamma + beta

        return out

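# --- Illustrative sketch (editor's addition, not part of the ReActor diff) ---
# A lone SPADE layer modulating 64-channel features with a 3-channel guidance
# map (HiFaceGAN passes the input image where the original SPADE passed a
# segmentation map); the guidance is resized to the feature size internally.
import torch

spade = SPADE('spadeinstance3x3', norm_nc=64, label_nc=3)
feat = torch.rand(2, 64, 32, 32)
guide = torch.rand(2, 3, 128, 128)
out = spade(feat, guide)  # (2, 64, 32, 32)
# --- end sketch ---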
class SPADEResnetBlock(nn.Module):
    """
    ResNet block that uses SPADE. It differs from the ResNet block of pix2pixHD in that
    it takes in the segmentation map as input, learns the skip connection if necessary,
    and applies normalization first and then convolution.
    This architecture seemed like a standard architecture for unconditional or
    class-conditional GAN architecture using residual block.
    The code was inspired from https://github.com/LMescheder/GAN_stability.
    """

    def __init__(self, fin, fout, norm_g='spectralspadesyncbatch3x3', semantic_nc=3):
        super().__init__()
        # Attributes
        self.learned_shortcut = (fin != fout)
        fmiddle = min(fin, fout)

        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # apply spectral norm if specified
        if 'spectral' in norm_g:
            self.conv_0 = spectral_norm(self.conv_0)
            self.conv_1 = spectral_norm(self.conv_1)
            if self.learned_shortcut:
                self.conv_s = spectral_norm(self.conv_s)

        # define normalization layers
        spade_config_str = norm_g.replace('spectral', '')
        self.norm_0 = SPADE(spade_config_str, fin, semantic_nc)
        self.norm_1 = SPADE(spade_config_str, fmiddle, semantic_nc)
        if self.learned_shortcut:
            self.norm_s = SPADE(spade_config_str, fin, semantic_nc)

    # note the resnet block with SPADE also takes in |seg|,
    # the semantic segmentation map as input
    def forward(self, x, seg):
        x_s = self.shortcut(x, seg)
        dx = self.conv_0(self.act(self.norm_0(x, seg)))
        dx = self.conv_1(self.act(self.norm_1(dx, seg)))
        out = x_s + dx
        return out

    def shortcut(self, x, seg):
        if self.learned_shortcut:
            x_s = self.conv_s(self.norm_s(x, seg))
        else:
            x_s = x
        return x_s

    def act(self, x):
        return F.leaky_relu(x, 2e-1)


class BaseNetwork(nn.Module):
    """A basis for hifacegan archs with custom initialization"""

    def init_weights(self, init_type='normal', gain=0.02):

        def init_func(m):
            classname = m.__class__.__name__
            if classname.find('BatchNorm2d') != -1:
                if hasattr(m, 'weight') and m.weight is not None:
                    init.normal_(m.weight.data, 1.0, gain)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)
            elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                if init_type == 'normal':
                    init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'xavier_uniform':
                    init.xavier_uniform_(m.weight.data, gain=1.0)
                elif init_type == 'kaiming':
                    init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    init.orthogonal_(m.weight.data, gain=gain)
                elif init_type == 'none':  # uses pytorch's default init method
                    m.reset_parameters()
                else:
                    raise NotImplementedError(f'initialization method [{init_type}] is not implemented')
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

        # propagate to children
        for m in self.children():
            if hasattr(m, 'init_weights'):
                m.init_weights(init_type, gain)

    def forward(self, x):
        pass


def lip2d(x, logit, kernel=3, stride=2, padding=1):
    weight = logit.exp()
    return F.avg_pool2d(x * weight, kernel, stride, padding) / F.avg_pool2d(weight, kernel, stride, padding)


class SoftGate(nn.Module):
    COEFF = 12.0

    def forward(self, x):
        return torch.sigmoid(x).mul(self.COEFF)


class SimplifiedLIP(nn.Module):

    def __init__(self, channels):
        super(SimplifiedLIP, self).__init__()
        self.logit = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False), nn.InstanceNorm2d(channels, affine=True),
            SoftGate())

    def init_layer(self):
        self.logit[0].weight.data.fill_(0.0)

    def forward(self, x):
        frac = lip2d(x, self.logit(x))
        return frac


class LIPEncoder(BaseNetwork):
    """Local Importance-based Pooling (Ziteng Gao et al., ICCV 2019)"""

    def __init__(self, input_nc, ngf, sw, sh, n_2xdown, norm_layer=nn.InstanceNorm2d):
        super().__init__()
        self.sw = sw
        self.sh = sh
        self.max_ratio = 16
        # 20200310: Several Convolution (stride 1) + LIP blocks, 4 fold
        kw = 3
        pw = (kw - 1) // 2

        model = [
            nn.Conv2d(input_nc, ngf, kw, stride=1, padding=pw, bias=False),
            norm_layer(ngf),
            nn.ReLU(),
        ]
        cur_ratio = 1
        for i in range(n_2xdown):
            next_ratio = min(cur_ratio * 2, self.max_ratio)
            model += [
                SimplifiedLIP(ngf * cur_ratio),
                nn.Conv2d(ngf * cur_ratio, ngf * next_ratio, kw, stride=1, padding=pw),
                norm_layer(ngf * next_ratio),
            ]
            cur_ratio = next_ratio
            if i < n_2xdown - 1:
                model += [nn.ReLU(inplace=True)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


def get_nonspade_norm_layer(norm_type='instance'):
    # helper function to get # output channels of the previous layer
    def get_out_channel(layer):
        if hasattr(layer, 'out_channels'):
            return getattr(layer, 'out_channels')
        return layer.weight.size(0)

    # this function will be returned
    def add_norm_layer(layer):
        nonlocal norm_type
        subnorm_type = norm_type  # editor's fix: guard the non-'spectral'-prefixed path, which otherwise left subnorm_type unset
        if norm_type.startswith('spectral'):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len('spectral'):]

        if subnorm_type == 'none' or len(subnorm_type) == 0:
            return layer

        # remove bias in the previous layer, which is meaningless
        # since it has no effect after normalization
        if getattr(layer, 'bias', None) is not None:
            delattr(layer, 'bias')
            layer.register_parameter('bias', None)

        if subnorm_type == 'batch':
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == 'sync_batch':
            print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
            # norm_layer = SynchronizedBatchNorm2d(
            #     get_out_channel(layer), affine=True)
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        elif subnorm_type == 'instance':
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        else:
            raise ValueError(f'normalization layer {subnorm_type} is not recognized')

        return nn.Sequential(layer, norm_layer)

    print('This is a legacy from nvlabs/SPADE, and will be removed in future versions.')
    return add_norm_layer
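# --- Illustrative sketch (editor's addition, not part of the ReActor diff) ---
# Wrapping a conv with the 'spectralinstance' recipe used by
# NLayerDiscriminator above: the helper applies spectral norm, strips the
# now-redundant bias, and returns a Sequential(conv, InstanceNorm2d).
import torch.nn as nn

add_norm = get_nonspade_norm_layer('spectralinstance')
layer = add_norm(nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1))
# --- end sketch ---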
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/inception.py
ADDED
@@ -0,0 +1,307 @@
# Modified from https://github.com/mseitzer/pytorch-fid/blob/master/pytorch_fid/inception.py # noqa: E501
# For FID metric

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.model_zoo import load_url
from torchvision import models

# Inception weights ported to Pytorch from
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'  # noqa: E501
LOCAL_FID_WEIGHTS = 'experiments/pretrained_models/pt_inception-2015-12-05-6726825d.pth'  # noqa: E501


class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning feature maps"""

    # Index of default block of inception to return,
    # corresponds to output of final average pooling
    DEFAULT_BLOCK_INDEX = 3

    # Maps feature dimensionality to their output blocks indices
    BLOCK_INDEX_BY_DIM = {
        64: 0,  # First max pooling features
        192: 1,  # Second max pooling features
        768: 2,  # Pre-aux classifier features
        2048: 3  # Final average pooling features
    }

    def __init__(self,
                 output_blocks=(DEFAULT_BLOCK_INDEX, ),  # trailing comma: must be a tuple for sorted()/max() below
                 resize_input=True,
                 normalize_input=True,
                 requires_grad=False,
                 use_fid_inception=True):
        """Build pretrained InceptionV3.

        Args:
            output_blocks (list[int]): Indices of blocks to return features of.
                Possible values are:
                - 0: corresponds to output of first max pooling
                - 1: corresponds to output of second max pooling
                - 2: corresponds to output which is fed to aux classifier
                - 3: corresponds to output of final average pooling
            resize_input (bool): If true, bilinearly resizes input to width and
                height 299 before feeding input to model. As the network
                without fully connected layers is fully convolutional, it
                should be able to handle inputs of arbitrary size, so resizing
                might not be strictly needed. Default: True.
            normalize_input (bool): If true, scales the input from range (0, 1)
                to the range the pretrained Inception network expects,
                namely (-1, 1). Default: True.
            requires_grad (bool): If true, parameters of the model require
                gradients. Possibly useful for finetuning the network.
                Default: False.
            use_fid_inception (bool): If true, uses the pretrained Inception
                model used in Tensorflow's FID implementation.
                If false, uses the pretrained Inception model available in
                torchvision. The FID Inception model has different weights
                and a slightly different structure from torchvision's
                Inception model. If you want to compute FID scores, you are
                strongly advised to set this parameter to true to get
                comparable results. Default: True.
        """
        super(InceptionV3, self).__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        assert self.last_needed_block <= 3, ('Last possible output block index is 3')

        self.blocks = nn.ModuleList()

        if use_fid_inception:
            inception = fid_inception_v3()
        else:
            try:
                inception = models.inception_v3(pretrained=True, init_weights=False)
            except TypeError:
                # pytorch < 1.5 does not have init_weights for inception_v3
                inception = models.inception_v3(pretrained=True)

        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2)
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: maxpool1 to maxpool2
        if self.last_needed_block >= 1:
            block1 = [inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, nn.MaxPool2d(kernel_size=3, stride=2)]
            self.blocks.append(nn.Sequential(*block1))

        # Block 2: maxpool2 to aux classifier
        if self.last_needed_block >= 2:
            block2 = [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ]
            self.blocks.append(nn.Sequential(*block2))

        # Block 3: aux classifier to final avgpool
        if self.last_needed_block >= 3:
            block3 = [
                inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1))
            ]
            self.blocks.append(nn.Sequential(*block3))

        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, x):
        """Get Inception feature maps.

        Args:
            x (Tensor): Input tensor of shape (b, 3, h, w).
                Values are expected to be in range (-1, 1). You can also input
                (0, 1) with setting normalize_input = True.

        Returns:
            list[Tensor]: Corresponding to the selected output block, sorted
            ascending by index.
        """
        output = []

        if self.resize_input:
            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)

        if self.normalize_input:
            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)

        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.output_blocks:
                output.append(x)

            if idx == self.last_needed_block:
                break

        return output


def fid_inception_v3():
    """Build pretrained Inception model for FID computation.

    The Inception model for FID computation uses a different set of weights
    and has a slightly different structure than torchvision's Inception.

    This method first constructs torchvision's Inception and then patches the
    necessary parts that are different in the FID Inception model.
    """
    try:
        inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False, init_weights=False)
    except TypeError:
        # pytorch < 1.5 does not have init_weights for inception_v3
        inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False)

    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
    inception.Mixed_7b = FIDInceptionE_1(1280)
    inception.Mixed_7c = FIDInceptionE_2(2048)

    if os.path.exists(LOCAL_FID_WEIGHTS):
        state_dict = torch.load(LOCAL_FID_WEIGHTS, map_location=lambda storage, loc: storage)
    else:
        state_dict = load_url(FID_WEIGHTS_URL, progress=True)

    inception.load_state_dict(state_dict)
    return inception

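# --- Illustrative sketch (editor's addition, not part of the ReActor diff) ---
# Extracting the 2048-d pool3 features used for FID from a dummy batch. Run
# this only after the whole module is imported, since fid_inception_v3()
# needs the FIDInception* classes defined below. The FID weights are
# downloaded from FID_WEIGHTS_URL on first use unless the local checkpoint
# path above already exists.
import torch

inception = InceptionV3(output_blocks=(3, ), use_fid_inception=True)
inception.eval()
imgs = torch.rand(4, 3, 299, 299)  # values in (0, 1); normalize_input rescales to (-1, 1)
with torch.no_grad():
    feats = inception(imgs)[0]     # (4, 2048, 1, 1)
# --- end sketch ---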
class FIDInceptionA(models.inception.InceptionA):
    """InceptionA block patched for FID computation"""

    def __init__(self, in_channels, pool_features):
        super(FIDInceptionA, self).__init__(in_channels, pool_features)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionC(models.inception.InceptionC):
    """InceptionC block patched for FID computation"""

    def __init__(self, in_channels, channels_7x7):
        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_1(models.inception.InceptionE):
    """First InceptionE block patched for FID computation"""

    def __init__(self, in_channels):
        super(FIDInceptionE_1, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: Tensorflow's average pool does not use the padded zeros in
        # its average calculation
        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class FIDInceptionE_2(models.inception.InceptionE):
    """Second InceptionE block patched for FID computation"""

    def __init__(self, in_channels):
        super(FIDInceptionE_2, self).__init__(in_channels)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        # Patch: The FID Inception model uses max pooling instead of average
        # pooling. This is likely an error in this specific Inception
|
301 |
+
# implementation, as other Inception models use average pooling here
|
302 |
+
# (which matches the description in the paper).
|
303 |
+
branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
|
304 |
+
branch_pool = self.branch_pool(branch_pool)
|
305 |
+
|
306 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
307 |
+
return torch.cat(outputs, 1)
|
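As a quick orientation for the FID helper above, a smoke test along these lines should exercise the patched blocks (an illustrative sketch, not part of this diff; it assumes the `r_basicsr` package root is importable, that the FID weights URL is reachable, and a torchvision version that still accepts the `pretrained` keyword):

import torch

from r_basicsr.archs.inception import fid_inception_v3

# Builds torchvision's Inception, swaps in the FIDInception* blocks above,
# and loads the FID weights (downloading them on first use).
net = fid_inception_v3().eval()
with torch.no_grad():
    x = torch.randn(2, 3, 299, 299)  # Inception v3 expects 299x299 inputs
    logits = net(x)                  # (2, 1008) logits from the patched classifier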
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/rcan_arch.py
ADDED
@@ -0,0 +1,135 @@
import torch
from torch import nn as nn

from r_basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import Upsample, make_layer


class ChannelAttention(nn.Module):
    """Channel attention used in RCAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        squeeze_factor (int): Channel squeeze factor. Default: 16.
    """

    def __init__(self, num_feat, squeeze_factor=16):
        super(ChannelAttention, self).__init__()
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
            nn.ReLU(inplace=True), nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), nn.Sigmoid())

    def forward(self, x):
        y = self.attention(x)
        return x * y


class RCAB(nn.Module):
    """Residual Channel Attention Block (RCAB) used in RCAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        squeeze_factor (int): Channel squeeze factor. Default: 16.
        res_scale (float): Scale the residual. Default: 1.
    """

    def __init__(self, num_feat, squeeze_factor=16, res_scale=1):
        super(RCAB, self).__init__()
        self.res_scale = res_scale

        self.rcab = nn.Sequential(
            nn.Conv2d(num_feat, num_feat, 3, 1, 1), nn.ReLU(True), nn.Conv2d(num_feat, num_feat, 3, 1, 1),
            ChannelAttention(num_feat, squeeze_factor))

    def forward(self, x):
        res = self.rcab(x) * self.res_scale
        return res + x


class ResidualGroup(nn.Module):
    """Residual Group of RCAB.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_block (int): Block number in the body network.
        squeeze_factor (int): Channel squeeze factor. Default: 16.
        res_scale (float): Scale the residual. Default: 1.
    """

    def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1):
        super(ResidualGroup, self).__init__()

        self.residual_group = make_layer(
            RCAB, num_block, num_feat=num_feat, squeeze_factor=squeeze_factor, res_scale=res_scale)
        self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

    def forward(self, x):
        res = self.conv(self.residual_group(x))
        return res + x


@ARCH_REGISTRY.register()
class RCAN(nn.Module):
    """Residual Channel Attention Networks.

    Paper: Image Super-Resolution Using Very Deep Residual Channel Attention Networks.
    Ref git repo: https://github.com/yulunzhang/RCAN.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        num_group (int): Number of ResidualGroup. Default: 10.
        num_block (int): Number of RCAB in ResidualGroup. Default: 16.
        squeeze_factor (int): Channel squeeze factor. Default: 16.
        upscale (int): Upsampling factor. Support 2^n and 3.
            Default: 4.
        res_scale (float): Used to scale the residual in residual block.
            Default: 1.
        img_range (float): Image range. Default: 255.
        rgb_mean (tuple[float]): Image mean in RGB orders.
            Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
    """

    def __init__(self,
                 num_in_ch,
                 num_out_ch,
                 num_feat=64,
                 num_group=10,
                 num_block=16,
                 squeeze_factor=16,
                 upscale=4,
                 res_scale=1,
                 img_range=255.,
                 rgb_mean=(0.4488, 0.4371, 0.4040)):
        super(RCAN, self).__init__()

        self.img_range = img_range
        self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)

        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(
            ResidualGroup,
            num_group,
            num_feat=num_feat,
            num_block=num_block,
            squeeze_factor=squeeze_factor,
            res_scale=res_scale)
        self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.upsample = Upsample(upscale, num_feat)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

    def forward(self, x):
        self.mean = self.mean.type_as(x)

        x = (x - self.mean) * self.img_range
        x = self.conv_first(x)
        res = self.conv_after_body(self.body(x))
        res += x

        x = self.conv_last(self.upsample(res))
        x = x / self.img_range + self.mean

        return x
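A minimal shape-check sketch for the RCAN class above (illustrative only, with untrained weights; it assumes the `r_basicsr` package root is importable):

import torch

from r_basicsr.archs.rcan_arch import RCAN

model = RCAN(num_in_ch=3, num_out_ch=3, upscale=4).eval()
with torch.no_grad():
    lr = torch.rand(1, 3, 64, 64)   # low-resolution input in (0, 1)
    sr = model(lr)                  # -> torch.Size([1, 3, 256, 256])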
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/ridnet_arch.py
ADDED
@@ -0,0 +1,184 @@
import torch
import torch.nn as nn

from r_basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import ResidualBlockNoBN, make_layer


class MeanShift(nn.Conv2d):
    """Data normalization with mean and std.

    Args:
        rgb_range (int): Maximum value of RGB.
        rgb_mean (list[float]): Mean for RGB channels.
        rgb_std (list[float]): Std for RGB channels.
        sign (int): For subtraction, sign is -1, for addition, sign is 1.
            Default: -1.
        requires_grad (bool): Whether to update the self.weight and self.bias.
            Default: True.
    """

    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1, requires_grad=True):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
        self.weight.data.div_(std.view(3, 1, 1, 1))
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
        self.bias.data.div_(std)
        self.requires_grad = requires_grad


class EResidualBlockNoBN(nn.Module):
    """Enhanced Residual block without BN.

    There are three convolution layers in the residual branch.

    It has a style of:
        ---Conv-ReLU-Conv-ReLU-Conv-+-ReLU-
         |__________________________|
    """

    def __init__(self, in_channels, out_channels):
        super(EResidualBlockNoBN, self).__init__()

        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 1, 1, 0),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.body(x)
        out = self.relu(out + x)
        return out


class MergeRun(nn.Module):
    """Merge-and-run unit.

    This unit contains two branches with different dilated convolutions,
    followed by a convolution to process the concatenated features.

    Paper: Real Image Denoising with Feature Attention
    Ref git repo: https://github.com/saeed-anwar/RIDNet
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(MergeRun, self).__init__()

        self.dilation1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size, stride, 2, 2), nn.ReLU(inplace=True))
        self.dilation2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, 3, 3), nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size, stride, 4, 4), nn.ReLU(inplace=True))

        self.aggregation = nn.Sequential(
            nn.Conv2d(out_channels * 2, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True))

    def forward(self, x):
        dilation1 = self.dilation1(x)
        dilation2 = self.dilation2(x)
        out = torch.cat([dilation1, dilation2], dim=1)
        out = self.aggregation(out)
        out = out + x
        return out


class ChannelAttention(nn.Module):
    """Channel attention.

    Args:
        mid_channels (int): Channel number of intermediate features.
        squeeze_factor (int): Channel squeeze factor. Default: 16.
    """

    def __init__(self, mid_channels, squeeze_factor=16):
        super(ChannelAttention, self).__init__()
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Conv2d(mid_channels, mid_channels // squeeze_factor, 1, padding=0),
            nn.ReLU(inplace=True), nn.Conv2d(mid_channels // squeeze_factor, mid_channels, 1, padding=0), nn.Sigmoid())

    def forward(self, x):
        y = self.attention(x)
        return x * y


class EAM(nn.Module):
    """Enhancement attention modules (EAM) in RIDNet.

    This module contains a merge-and-run unit, a residual block,
    an enhanced residual block and a feature attention unit.

    Attributes:
        merge: The merge-and-run unit.
        block1: The residual block.
        block2: The enhanced residual block.
        ca: The feature/channel attention unit.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super(EAM, self).__init__()

        self.merge = MergeRun(in_channels, mid_channels)
        self.block1 = ResidualBlockNoBN(mid_channels)
        self.block2 = EResidualBlockNoBN(mid_channels, out_channels)
        self.ca = ChannelAttention(out_channels)
        # The residual block in the paper contains a relu after addition.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.merge(x)
        out = self.relu(self.block1(out))
        out = self.block2(out)
        out = self.ca(out)
        return out


@ARCH_REGISTRY.register()
class RIDNet(nn.Module):
    """RIDNet: Real Image Denoising with Feature Attention.

    Ref git repo: https://github.com/saeed-anwar/RIDNet

    Args:
        in_channels (int): Channel number of inputs.
        mid_channels (int): Channel number of EAM modules.
            Default: 64.
        out_channels (int): Channel number of outputs.
        num_block (int): Number of EAM. Default: 4.
        img_range (float): Image range. Default: 255.
        rgb_mean (tuple[float]): Image mean in RGB orders.
            Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
    """

    def __init__(self,
                 in_channels,
                 mid_channels,
                 out_channels,
                 num_block=4,
                 img_range=255.,
                 rgb_mean=(0.4488, 0.4371, 0.4040),
                 rgb_std=(1.0, 1.0, 1.0)):
        super(RIDNet, self).__init__()

        self.sub_mean = MeanShift(img_range, rgb_mean, rgb_std)
        self.add_mean = MeanShift(img_range, rgb_mean, rgb_std, 1)

        self.head = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
        self.body = make_layer(
            EAM, num_block, in_channels=mid_channels, mid_channels=mid_channels, out_channels=mid_channels)
        self.tail = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        res = self.sub_mean(x)
        res = self.tail(self.body(self.relu(self.head(res))))
        res = self.add_mean(res)

        out = x + res
        return out
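Since RIDNet predicts a residual on top of its input, the output keeps the input resolution. A minimal sketch (illustrative only, with untrained weights, assuming `r_basicsr` is importable):

import torch

from r_basicsr.archs.ridnet_arch import RIDNet

model = RIDNet(in_channels=3, mid_channels=64, out_channels=3).eval()
with torch.no_grad():
    noisy = torch.rand(1, 3, 128, 128)
    denoised = model(noisy)  # same shape: the net adds a learned residual to x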
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/rrdbnet_arch.py
ADDED
@@ -0,0 +1,119 @@
import torch
from torch import nn as nn
from torch.nn import functional as F

from r_basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import default_init_weights, make_layer, pixel_unshuffle


class ResidualDenseBlock(nn.Module):
    """Residual Dense Block.

    Used in RRDB block in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super(ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Empirically, we use 0.2 to scale the residual for better performance
        return x5 * 0.2 + x


class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Used in RRDB-Net in ESRGAN.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super(RRDB, self).__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        # Empirically, we use 0.2 to scale the residual for better performance
        return out * 0.2 + x


@ARCH_REGISTRY.register()
class RRDBNet(nn.Module):
    """Networks consisting of Residual in Residual Dense Block, which is used
    in ESRGAN.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    We extend ESRGAN for scale x2 and scale x1.
    Note: This is one option for scale 1, scale 2 in RRDBNet.
    We first employ the pixel-unshuffle (an inverse operation of pixel-shuffle) to reduce the spatial size
    and enlarge the channel size before feeding inputs into the main ESRGAN architecture.

    Args:
        num_in_ch (int): Channel number of inputs.
        num_out_ch (int): Channel number of outputs.
        scale (int): Upsampling scale. Supports 1, 2 and 4. Default: 4.
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        num_block (int): Block number in the trunk network. Default: 23.
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
        super(RRDBNet, self).__init__()
        self.scale = scale
        if scale == 2:
            num_in_ch = num_in_ch * 4
        elif scale == 1:
            num_in_ch = num_in_ch * 16
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        # upsample
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        if self.scale == 2:
            feat = pixel_unshuffle(x, scale=2)
        elif self.scale == 1:
            feat = pixel_unshuffle(x, scale=4)
        else:
            feat = x
        feat = self.conv_first(feat)
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat
        # upsample
        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out
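The scale handling above is the notable detail: for scale 2 and 1 the input is pixel-unshuffled first (hence the `num_in_ch * 4` / `* 16` adjustments), while the tail always applies two 2x upsamples. A minimal sketch (illustrative only, with untrained weights, assuming `r_basicsr` is importable):

import torch

from r_basicsr.archs.rrdbnet_arch import RRDBNet

model = RRDBNet(num_in_ch=3, num_out_ch=3, scale=4).eval()
with torch.no_grad():
    sr = model(torch.rand(1, 3, 64, 64))  # -> torch.Size([1, 3, 256, 256])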
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/spynet_arch.py
ADDED
@@ -0,0 +1,96 @@
import math
import torch
from torch import nn as nn
from torch.nn import functional as F

from r_basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import flow_warp


class BasicModule(nn.Module):
    """Basic Module for SpyNet.
    """

    def __init__(self):
        super(BasicModule, self).__init__()

        self.basic_module = nn.Sequential(
            nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))

    def forward(self, tensor_input):
        return self.basic_module(tensor_input)


@ARCH_REGISTRY.register()
class SpyNet(nn.Module):
    """SpyNet architecture.

    Args:
        load_path (str): Path for pretrained SpyNet. Default: None.
    """

    def __init__(self, load_path=None):
        super(SpyNet, self).__init__()
        self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)])
        if load_path:
            self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])

        self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def preprocess(self, tensor_input):
        tensor_output = (tensor_input - self.mean) / self.std
        return tensor_output

    def process(self, ref, supp):
        flow = []

        ref = [self.preprocess(ref)]
        supp = [self.preprocess(supp)]

        for level in range(5):
            ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False))
            supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False))

        flow = ref[0].new_zeros(
            [ref[0].size(0), 2,
             int(math.floor(ref[0].size(2) / 2.0)),
             int(math.floor(ref[0].size(3) / 2.0))])

        for level in range(len(ref)):
            upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0

            if upsampled_flow.size(2) != ref[level].size(2):
                upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate')
            if upsampled_flow.size(3) != ref[level].size(3):
                upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate')

            flow = self.basic_module[level](torch.cat([
                ref[level],
                flow_warp(
                    supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'),
                upsampled_flow
            ], 1)) + upsampled_flow

        return flow

    def forward(self, ref, supp):
        assert ref.size() == supp.size()

        h, w = ref.size(2), ref.size(3)
        w_floor = math.floor(math.ceil(w / 32.0) * 32.0)
        h_floor = math.floor(math.ceil(h / 32.0) * 32.0)

        ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False)
        supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False)

        flow = F.interpolate(input=self.process(ref, supp), size=(h, w), mode='bilinear', align_corners=False)

        flow[:, 0, :, :] *= float(w) / float(w_floor)
        flow[:, 1, :, :] *= float(h) / float(h_floor)

        return flow
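SpyNet resizes both frames to a multiple of 32, estimates flow coarse-to-fine over the six BasicModule levels, then rescales the flow back to the input resolution. A minimal sketch (illustrative only; with `load_path=None` the weights are random, so the flow values are meaningless, but the shapes are exercised):

import torch

from r_basicsr.archs.spynet_arch import SpyNet

spynet = SpyNet(load_path=None).eval()  # pass a checkpoint path for real flow
with torch.no_grad():
    ref = torch.rand(1, 3, 64, 64)
    supp = torch.rand(1, 3, 64, 64)
    flow = spynet(ref, supp)            # -> torch.Size([1, 2, 64, 64]) dense flow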
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/srresnet_arch.py
ADDED
@@ -0,0 +1,65 @@
from torch import nn as nn
from torch.nn import functional as F

from r_basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import ResidualBlockNoBN, default_init_weights, make_layer


@ARCH_REGISTRY.register()
class MSRResNet(nn.Module):
    """Modified SRResNet.

    A compact version modified from SRResNet in
    "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network".
    It uses residual blocks without BN, similar to EDSR.
    Currently, it supports x2, x3 and x4 upsampling scale factor.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_out_ch (int): Channel number of outputs. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_block (int): Block number in the body network. Default: 16.
        upscale (int): Upsampling factor. Support x2, x3 and x4. Default: 4.
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4):
        super(MSRResNet, self).__init__()
        self.upscale = upscale

        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat)

        # upsampling
        if self.upscale in [2, 3]:
            self.upconv1 = nn.Conv2d(num_feat, num_feat * self.upscale * self.upscale, 3, 1, 1)
            self.pixel_shuffle = nn.PixelShuffle(self.upscale)
        elif self.upscale == 4:
            self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
            self.upconv2 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
            self.pixel_shuffle = nn.PixelShuffle(2)

        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        # initialization
        default_init_weights([self.conv_first, self.upconv1, self.conv_hr, self.conv_last], 0.1)
        if self.upscale == 4:
            default_init_weights(self.upconv2, 0.1)

    def forward(self, x):
        feat = self.lrelu(self.conv_first(x))
        out = self.body(feat)

        if self.upscale == 4:
            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
            out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
        elif self.upscale in [2, 3]:
            out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))

        out = self.conv_last(self.lrelu(self.conv_hr(out)))
        base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
        out += base
        return out
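Note how the upsampling branch differs by factor: x2/x3 use a single pixel-shuffle, x4 stacks two x2 pixel-shuffles, and a bilinear skip of the input is added at the end. A minimal sketch (illustrative only, with untrained weights, assuming `r_basicsr` is importable):

import torch

from r_basicsr.archs.srresnet_arch import MSRResNet

model = MSRResNet(upscale=3).eval()
with torch.no_grad():
    sr = model(torch.rand(1, 3, 50, 50))  # -> torch.Size([1, 3, 150, 150])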
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/srvgg_arch.py
ADDED
@@ -0,0 +1,70 @@
from torch import nn as nn
from torch.nn import functional as F

from r_basicsr.utils.registry import ARCH_REGISTRY


@ARCH_REGISTRY.register(suffix='basicsr')
class SRVGGNetCompact(nn.Module):
    """A compact VGG-style network structure for super-resolution.

    It is a compact network structure, which performs upsampling in the last layer and no convolution is
    conducted on the HR feature space.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_out_ch (int): Channel number of outputs. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_conv (int): Number of convolution layers in the body network. Default: 16.
        upscale (int): Upsampling factor. Default: 4.
        act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
        super(SRVGGNetCompact, self).__init__()
        self.num_in_ch = num_in_ch
        self.num_out_ch = num_out_ch
        self.num_feat = num_feat
        self.num_conv = num_conv
        self.upscale = upscale
        self.act_type = act_type

        self.body = nn.ModuleList()
        # the first conv
        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
        # the first activation
        if act_type == 'relu':
            activation = nn.ReLU(inplace=True)
        elif act_type == 'prelu':
            activation = nn.PReLU(num_parameters=num_feat)
        elif act_type == 'leakyrelu':
            activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.body.append(activation)

        # the body structure
        for _ in range(num_conv):
            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
            # activation
            if act_type == 'relu':
                activation = nn.ReLU(inplace=True)
            elif act_type == 'prelu':
                activation = nn.PReLU(num_parameters=num_feat)
            elif act_type == 'leakyrelu':
                activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
            self.body.append(activation)

        # the last conv
        self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
        # upsample
        self.upsampler = nn.PixelShuffle(upscale)

    def forward(self, x):
        out = x
        for i in range(0, len(self.body)):
            out = self.body[i](out)

        out = self.upsampler(out)
        # add the nearest upsampled image, so that the network learns the residual
        base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
        out += base
        return out
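Because all convolutions run at the low-resolution size and the only upsampling is the final pixel-shuffle plus a nearest-neighbour skip, this architecture stays cheap at inference. A minimal sketch (illustrative only, with untrained weights, assuming `r_basicsr` is importable):

import torch

from r_basicsr.archs.srvgg_arch import SRVGGNetCompact

model = SRVGGNetCompact(num_feat=64, num_conv=16, upscale=4, act_type='prelu').eval()
with torch.no_grad():
    sr = model(torch.rand(1, 3, 64, 64))  # -> torch.Size([1, 3, 256, 256])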
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/stylegan2_arch.py
ADDED
@@ -0,0 +1,799 @@
import math
import random
import torch
from torch import nn
from torch.nn import functional as F

from r_basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
from r_basicsr.ops.upfirdn2d import upfirdn2d
from r_basicsr.utils.registry import ARCH_REGISTRY


class NormStyleCode(nn.Module):

    def forward(self, x):
        """Normalize the style codes.

        Args:
            x (Tensor): Style codes with shape (b, c).

        Returns:
            Tensor: Normalized tensor.
        """
        return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)


def make_resample_kernel(k):
    """Make resampling kernel for UpFirDn.

    Args:
        k (list[int]): A list indicating the 1D resample kernel magnitude.

    Returns:
        Tensor: 2D resampled kernel.
    """
    k = torch.tensor(k, dtype=torch.float32)
    if k.ndim == 1:
        k = k[None, :] * k[:, None]  # to 2D kernel, outer product
    # normalize
    k /= k.sum()
    return k


class UpFirDnUpsample(nn.Module):
    """Upsample, FIR filter, and downsample (upsample version).

    References:
    1. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.upfirdn.html  # noqa: E501
    2. http://www.ece.northwestern.edu/local-apps/matlabhelp/toolbox/signal/upfirdn.html  # noqa: E501

    Args:
        resample_kernel (list[int]): A list indicating the 1D resample kernel
            magnitude.
        factor (int): Upsampling scale factor. Default: 2.
    """

    def __init__(self, resample_kernel, factor=2):
        super(UpFirDnUpsample, self).__init__()
        self.kernel = make_resample_kernel(resample_kernel) * (factor**2)
        self.factor = factor

        pad = self.kernel.shape[0] - factor
        self.pad = ((pad + 1) // 2 + factor - 1, pad // 2)

    def forward(self, x):
        out = upfirdn2d(x, self.kernel.type_as(x), up=self.factor, down=1, pad=self.pad)
        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(factor={self.factor})')


class UpFirDnDownsample(nn.Module):
    """Upsample, FIR filter, and downsample (downsample version).

    Args:
        resample_kernel (list[int]): A list indicating the 1D resample kernel
            magnitude.
        factor (int): Downsampling scale factor. Default: 2.
    """

    def __init__(self, resample_kernel, factor=2):
        super(UpFirDnDownsample, self).__init__()
        self.kernel = make_resample_kernel(resample_kernel)
        self.factor = factor

        pad = self.kernel.shape[0] - factor
        self.pad = ((pad + 1) // 2, pad // 2)

    def forward(self, x):
        out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=self.factor, pad=self.pad)
        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(factor={self.factor})')


class UpFirDnSmooth(nn.Module):
    """Upsample, FIR filter, and downsample (smooth version).

    Args:
        resample_kernel (list[int]): A list indicating the 1D resample kernel
            magnitude.
        upsample_factor (int): Upsampling scale factor. Default: 1.
        downsample_factor (int): Downsampling scale factor. Default: 1.
        kernel_size (int): Kernel size. Default: 1.
    """

    def __init__(self, resample_kernel, upsample_factor=1, downsample_factor=1, kernel_size=1):
        super(UpFirDnSmooth, self).__init__()
        self.upsample_factor = upsample_factor
        self.downsample_factor = downsample_factor
        self.kernel = make_resample_kernel(resample_kernel)
        if upsample_factor > 1:
            self.kernel = self.kernel * (upsample_factor**2)

        if upsample_factor > 1:
            pad = (self.kernel.shape[0] - upsample_factor) - (kernel_size - 1)
            self.pad = ((pad + 1) // 2 + upsample_factor - 1, pad // 2 + 1)
        elif downsample_factor > 1:
            pad = (self.kernel.shape[0] - downsample_factor) + (kernel_size - 1)
            self.pad = ((pad + 1) // 2, pad // 2)
        else:
            raise NotImplementedError

    def forward(self, x):
        out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=1, pad=self.pad)
        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(upsample_factor={self.upsample_factor}'
                f', downsample_factor={self.downsample_factor})')


class EqualLinear(nn.Module):
    """Equalized Linear as StyleGAN2.

    Args:
        in_channels (int): Size of each sample.
        out_channels (int): Size of each output sample.
        bias (bool): If set to ``False``, the layer will not learn an additive
            bias. Default: ``True``.
        bias_init_val (float): Bias initialized value. Default: 0.
        lr_mul (float): Learning rate multiplier. Default: 1.
        activation (None | str): The activation after ``linear`` operation.
            Supported: 'fused_lrelu', None. Default: None.
    """

    def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
        super(EqualLinear, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.lr_mul = lr_mul
        self.activation = activation
        if self.activation not in ['fused_lrelu', None]:
            raise ValueError(f'Wrong activation value in EqualLinear: {activation}. '
                             "Supported ones are: ['fused_lrelu', None].")
        self.scale = (1 / math.sqrt(in_channels)) * lr_mul

        self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        if self.bias is None:
            bias = None
        else:
            bias = self.bias * self.lr_mul
        if self.activation == 'fused_lrelu':
            out = F.linear(x, self.weight * self.scale)
            out = fused_leaky_relu(out, bias)
        else:
            out = F.linear(x, self.weight * self.scale, bias=bias)
        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, bias={self.bias is not None})')


class ModulatedConv2d(nn.Module):
    """Modulated Conv2d used in StyleGAN2.

    There is no bias in ModulatedConv2d.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether to demodulate in the conv layer.
            Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
            Default: None.
        resample_kernel (list[int]): A list indicating the 1D resample kernel
            magnitude. Default: (1, 3, 3, 1).
        eps (float): A value added to the denominator for numerical stability.
            Default: 1e-8.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 resample_kernel=(1, 3, 3, 1),
                 eps=1e-8):
        super(ModulatedConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.demodulate = demodulate
        self.sample_mode = sample_mode
        self.eps = eps

        if self.sample_mode == 'upsample':
            self.smooth = UpFirDnSmooth(
                resample_kernel, upsample_factor=2, downsample_factor=1, kernel_size=kernel_size)
        elif self.sample_mode == 'downsample':
            self.smooth = UpFirDnSmooth(
                resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size)
        elif self.sample_mode is None:
            pass
        else:
            raise ValueError(f'Wrong sample mode {self.sample_mode}, '
                             "supported ones are ['upsample', 'downsample', None].")

        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
        # modulation inside each modulated conv
        self.modulation = EqualLinear(
            num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)

        self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
        self.padding = kernel_size // 2

    def forward(self, x, style):
        """Forward function.

        Args:
            x (Tensor): Tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).

        Returns:
            Tensor: Modulated tensor after convolution.
        """
        b, c, h, w = x.shape  # c = c_in
        # weight modulation
        style = self.modulation(style).view(b, 1, c, 1, 1)
        # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
        weight = self.scale * self.weight * style  # (b, c_out, c_in, k, k)

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(b, self.out_channels, 1, 1, 1)

        weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)

        if self.sample_mode == 'upsample':
            x = x.view(1, b * c, h, w)
            weight = weight.view(b, self.out_channels, c, self.kernel_size, self.kernel_size)
            weight = weight.transpose(1, 2).reshape(b * c, self.out_channels, self.kernel_size, self.kernel_size)
            out = F.conv_transpose2d(x, weight, padding=0, stride=2, groups=b)
            out = out.view(b, self.out_channels, *out.shape[2:4])
            out = self.smooth(out)
        elif self.sample_mode == 'downsample':
            x = self.smooth(x)
            x = x.view(1, b * c, *x.shape[2:4])
            out = F.conv2d(x, weight, padding=0, stride=2, groups=b)
            out = out.view(b, self.out_channels, *out.shape[2:4])
        else:
            x = x.view(1, b * c, h, w)
            # weight: (b*c_out, c_in, k, k), groups=b
            out = F.conv2d(x, weight, padding=self.padding, groups=b)
            out = out.view(b, self.out_channels, *out.shape[2:4])

        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size}, '
                f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')


class StyleConv(nn.Module):
    """Style conv.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        num_style_feat (int): Channel number of style features.
        demodulate (bool): Whether demodulate in the conv layer. Default: True.
        sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
            Default: None.
        resample_kernel (list[int]): A list indicating the 1D resample kernel
            magnitude. Default: (1, 3, 3, 1).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 num_style_feat,
                 demodulate=True,
                 sample_mode=None,
                 resample_kernel=(1, 3, 3, 1)):
        super(StyleConv, self).__init__()
        self.modulated_conv = ModulatedConv2d(
            in_channels,
            out_channels,
            kernel_size,
            num_style_feat,
            demodulate=demodulate,
            sample_mode=sample_mode,
            resample_kernel=resample_kernel)
        self.weight = nn.Parameter(torch.zeros(1))  # for noise injection
        self.activate = FusedLeakyReLU(out_channels)

    def forward(self, x, style, noise=None):
        # modulate
        out = self.modulated_conv(x, style)
        # noise injection
        if noise is None:
            b, _, h, w = out.shape
            noise = out.new_empty(b, 1, h, w).normal_()
        out = out + self.weight * noise
        # activation (with bias)
        out = self.activate(out)
        return out


class ToRGB(nn.Module):
    """To RGB from features.

    Args:
        in_channels (int): Channel number of input.
        num_style_feat (int): Channel number of style features.
        upsample (bool): Whether to upsample. Default: True.
        resample_kernel (list[int]): A list indicating the 1D resample kernel
            magnitude. Default: (1, 3, 3, 1).
    """

    def __init__(self, in_channels, num_style_feat, upsample=True, resample_kernel=(1, 3, 3, 1)):
        super(ToRGB, self).__init__()
        if upsample:
            self.upsample = UpFirDnUpsample(resample_kernel, factor=2)
        else:
            self.upsample = None
        self.modulated_conv = ModulatedConv2d(
            in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, x, style, skip=None):
        """Forward function.

        Args:
            x (Tensor): Feature tensor with shape (b, c, h, w).
            style (Tensor): Tensor with shape (b, num_style_feat).
            skip (Tensor): Base/skip tensor. Default: None.

        Returns:
            Tensor: RGB images.
        """
        out = self.modulated_conv(x, style)
        out = out + self.bias
        if skip is not None:
            if self.upsample:
                skip = self.upsample(skip)
            out = out + skip
        return out


class ConstantInput(nn.Module):
    """Constant input.

    Args:
        num_channel (int): Channel number of constant input.
        size (int): Spatial size of constant input.
    """

    def __init__(self, num_channel, size):
        super(ConstantInput, self).__init__()
        self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))

    def forward(self, batch):
        out = self.weight.repeat(batch, 1, 1, 1)
        return out


@ARCH_REGISTRY.register()
class StyleGAN2Generator(nn.Module):
    """StyleGAN2 Generator.

    Args:
        out_size (int): The spatial size of outputs.
        num_style_feat (int): Channel number of style features. Default: 512.
        num_mlp (int): Layer number of MLP style layers. Default: 8.
        channel_multiplier (int): Channel multiplier for large networks of
            StyleGAN2. Default: 2.
        resample_kernel (list[int]): A list indicating the 1D resample kernel
            magnitude. An outer product will be applied to extend the 1D
            resample kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
        lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
        narrow (float): Narrow ratio for channels. Default: 1.0.
    """

    def __init__(self,
                 out_size,
                 num_style_feat=512,
                 num_mlp=8,
                 channel_multiplier=2,
                 resample_kernel=(1, 3, 3, 1),
                 lr_mlp=0.01,
                 narrow=1):
        super(StyleGAN2Generator, self).__init__()
        # Style MLP layers
        self.num_style_feat = num_style_feat
        style_mlp_layers = [NormStyleCode()]
        for i in range(num_mlp):
            style_mlp_layers.append(
                EqualLinear(
                    num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
                    activation='fused_lrelu'))
        self.style_mlp = nn.Sequential(*style_mlp_layers)

        channels = {
            '4': int(512 * narrow),
            '8': int(512 * narrow),
            '16': int(512 * narrow),
            '32': int(512 * narrow),
            '64': int(256 * channel_multiplier * narrow),
            '128': int(128 * channel_multiplier * narrow),
            '256': int(64 * channel_multiplier * narrow),
            '512': int(32 * channel_multiplier * narrow),
            '1024': int(16 * channel_multiplier * narrow)
        }
        self.channels = channels

        self.constant_input = ConstantInput(channels['4'], size=4)
        self.style_conv1 = StyleConv(
            channels['4'],
            channels['4'],
            kernel_size=3,
            num_style_feat=num_style_feat,
            demodulate=True,
            sample_mode=None,
            resample_kernel=resample_kernel)
        self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, resample_kernel=resample_kernel)

        self.log_size = int(math.log(out_size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1
        self.num_latent = self.log_size * 2 - 2

        self.style_convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channels = channels['4']
        # noise
        for layer_idx in range(self.num_layers):
            resolution = 2**((layer_idx + 5) // 2)
            shape = [1, 1, resolution, resolution]
            self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
        # style convs and to_rgbs
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.style_convs.append(
                StyleConv(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode='upsample',
                    resample_kernel=resample_kernel,
                ))
            self.style_convs.append(
                StyleConv(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode=None,
                    resample_kernel=resample_kernel))
            self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True, resample_kernel=resample_kernel))
            in_channels = out_channels

    def make_noise(self):
        """Make noise for noise injection."""
        device = self.constant_input.weight.device
        noises = [torch.randn(1, 1, 4, 4, device=device)]

        for i in range(3, self.log_size + 1):
            for _ in range(2):
                noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

        return noises

    def get_latent(self, x):
        return self.style_mlp(x)

    def mean_latent(self, num_latent):
        latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
        latent = self.style_mlp(latent_in).mean(0, keepdim=True)
        return latent

    def forward(self,
                styles,
                input_is_latent=False,
                noise=None,
                randomize_noise=True,
                truncation=1,
                truncation_latent=None,
                inject_index=None,
                return_latents=False):
        """Forward function for StyleGAN2Generator.

        Args:
            styles (list[Tensor]): Sample codes of styles.
            input_is_latent (bool): Whether input is latent style.
                Default: False.
            noise (Tensor | None): Input noise or None. Default: None.
            randomize_noise (bool): Randomize noise, used when 'noise' is
                None. Default: True.
            truncation (float): Truncation ratio for the style latents.
                Default: 1.
            truncation_latent (Tensor | None): Mean latent used for
                truncation. Default: None.
            inject_index (int | None): The injection index for mixing noise.
                Default: None.
            return_latents (bool): Whether to return style latents.
                Default: False.
        """
        # style codes -> latents with Style MLP layer
        if not input_is_latent:
            styles = [self.style_mlp(s) for s in styles]
        # noises
        if noise is None:
            if randomize_noise:
                noise = [None] * self.num_layers  # for each style conv layer
            else:  # use the stored noise
                noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
        # style truncation
        if truncation < 1:
            style_truncation = []
            for style in styles:
                style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
            styles = style_truncation
        # get style latent with injection
        if len(styles) == 1:
            inject_index = self.num_latent

            if styles[0].ndim < 3:
                # repeat latent code for all the layers
                latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            else:  # used for encoder with different latent code for each layer
                latent = styles[0]
        elif len(styles) == 2:  # mixing noises
            if inject_index is None:
                inject_index = random.randint(1, self.num_latent - 1)
            latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
            latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
            latent = torch.cat([latent1, latent2], 1)

        # main generation
        out = self.constant_input(latent.shape[0])
        out = self.style_conv1(out, latent[:, 0], noise=noise[0])
        skip = self.to_rgb1(out, latent[:, 1])

        i = 1
        for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
                                                        noise[2::2], self.to_rgbs):
            out = conv1(out, latent[:, i], noise=noise1)
            out = conv2(out, latent[:, i + 1], noise=noise2)
            skip = to_rgb(out, latent[:, i + 2], skip)
            i += 2

        image = skip

        if return_latents:
            return image, latent
        else:
|
586 |
+
return image, None
|
587 |
+
|
588 |
+
|
589 |
+
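
Note (illustrative, not part of the patch): a minimal sketch of how this generator is typically driven, assuming only the classes defined in this file; the output size, batch size, and truncation ratio are arbitrary.

    # Hypothetical usage sketch for StyleGAN2Generator.
    gen = StyleGAN2Generator(out_size=256)
    z = [torch.randn(2, 512)]            # a single style code -> no style mixing
    mean_w = gen.mean_latent(4096)       # average latent for the truncation trick
    # truncation < 1 pulls each latent toward mean_w: w' = mean_w + t * (w - mean_w)
    img, _ = gen(z, truncation=0.7, truncation_latent=mean_w)
    print(img.shape)                     # torch.Size([2, 3, 256, 256])
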
class ScaledLeakyReLU(nn.Module):
    """Scaled LeakyReLU.

    Args:
        negative_slope (float): Negative slope. Default: 0.2.
    """

    def __init__(self, negative_slope=0.2):
        super(ScaledLeakyReLU, self).__init__()
        self.negative_slope = negative_slope

    def forward(self, x):
        out = F.leaky_relu(x, negative_slope=self.negative_slope)
        return out * math.sqrt(2)


class EqualConv2d(nn.Module):
    """Equalized Conv2d as in StyleGAN2.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Size of the convolving kernel.
        stride (int): Stride of the convolution. Default: 1.
        padding (int): Zero-padding added to both sides of the input.
            Default: 0.
        bias (bool): If ``True``, adds a learnable bias to the output.
            Default: ``True``.
        bias_init_val (float): Bias initialized value. Default: 0.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0):
        super(EqualConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)

        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        out = F.conv2d(
            x,
            self.weight * self.scale,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'kernel_size={self.kernel_size},'
                f' stride={self.stride}, padding={self.padding}, '
                f'bias={self.bias is not None})')


class ConvLayer(nn.Sequential):
    """Conv Layer used in StyleGAN2 Discriminator.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        kernel_size (int): Kernel size.
        downsample (bool): Whether to downsample by a factor of 2.
            Default: False.
        resample_kernel (list[int]): A list indicating the 1D resample
            kernel magnitude. An outer product will be applied to
            extend the 1D resample kernel to a 2D resample kernel.
            Default: (1, 3, 3, 1).
        bias (bool): Whether with bias. Default: True.
        activate (bool): Whether to use activation. Default: True.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 downsample=False,
                 resample_kernel=(1, 3, 3, 1),
                 bias=True,
                 activate=True):
        layers = []
        # downsample
        if downsample:
            layers.append(
                UpFirDnSmooth(resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size))
            stride = 2
            self.padding = 0
        else:
            stride = 1
            self.padding = kernel_size // 2
        # conv
        layers.append(
            EqualConv2d(
                in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias
                and not activate))
        # activation
        if activate:
            if bias:
                layers.append(FusedLeakyReLU(out_channels))
            else:
                layers.append(ScaledLeakyReLU(0.2))

        super(ConvLayer, self).__init__(*layers)


class ResBlock(nn.Module):
    """Residual block used in StyleGAN2 Discriminator.

    Args:
        in_channels (int): Channel number of the input.
        out_channels (int): Channel number of the output.
        resample_kernel (list[int]): A list indicating the 1D resample
            kernel magnitude. An outer product will be applied to
            extend the 1D resample kernel to a 2D resample kernel.
            Default: (1, 3, 3, 1).
    """

    def __init__(self, in_channels, out_channels, resample_kernel=(1, 3, 3, 1)):
        super(ResBlock, self).__init__()

        self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
        self.conv2 = ConvLayer(
            in_channels, out_channels, 3, downsample=True, resample_kernel=resample_kernel, bias=True, activate=True)
        self.skip = ConvLayer(
            in_channels, out_channels, 1, downsample=True, resample_kernel=resample_kernel, bias=False, activate=False)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        skip = self.skip(x)
        out = (out + skip) / math.sqrt(2)
        return out

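
Note (illustrative, not part of the patch): EqualConv2d implements StyleGAN2's equalized learning rate: weights are drawn from N(0, 1) and rescaled at runtime by 1/sqrt(fan_in), so all layers see comparable gradient magnitudes. A quick sanity check, assuming the class above:

    # Hypothetical check of the equalized-learning-rate scaling.
    conv = EqualConv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
    print(conv.scale)                    # 1 / sqrt(64 * 3**2) = 1/24 ~= 0.0417
    y = conv(torch.randn(1, 64, 32, 32))
    print(y.shape)                       # torch.Size([1, 128, 32, 32])
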
@ARCH_REGISTRY.register()
class StyleGAN2Discriminator(nn.Module):
    """StyleGAN2 Discriminator.

    Args:
        out_size (int): The spatial size of outputs.
        channel_multiplier (int): Channel multiplier for large networks of
            StyleGAN2. Default: 2.
        resample_kernel (list[int]): A list indicating the 1D resample kernel
            magnitude. An outer product will be applied to extend the 1D
            resample kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
        stddev_group (int): For group stddev statistics. Default: 4.
        narrow (float): Narrow ratio for channels. Default: 1.0.
    """

    def __init__(self, out_size, channel_multiplier=2, resample_kernel=(1, 3, 3, 1), stddev_group=4, narrow=1):
        super(StyleGAN2Discriminator, self).__init__()

        channels = {
            '4': int(512 * narrow),
            '8': int(512 * narrow),
            '16': int(512 * narrow),
            '32': int(512 * narrow),
            '64': int(256 * channel_multiplier * narrow),
            '128': int(128 * channel_multiplier * narrow),
            '256': int(64 * channel_multiplier * narrow),
            '512': int(32 * channel_multiplier * narrow),
            '1024': int(16 * channel_multiplier * narrow)
        }

        log_size = int(math.log(out_size, 2))

        conv_body = [ConvLayer(3, channels[f'{out_size}'], 1, bias=True, activate=True)]

        in_channels = channels[f'{out_size}']
        for i in range(log_size, 2, -1):
            out_channels = channels[f'{2**(i - 1)}']
            conv_body.append(ResBlock(in_channels, out_channels, resample_kernel))
            in_channels = out_channels
        self.conv_body = nn.Sequential(*conv_body)

        self.final_conv = ConvLayer(in_channels + 1, channels['4'], 3, bias=True, activate=True)
        self.final_linear = nn.Sequential(
            EqualLinear(
                channels['4'] * 4 * 4, channels['4'], bias=True, bias_init_val=0, lr_mul=1, activation='fused_lrelu'),
            EqualLinear(channels['4'], 1, bias=True, bias_init_val=0, lr_mul=1, activation=None),
        )
        self.stddev_group = stddev_group
        self.stddev_feat = 1

    def forward(self, x):
        out = self.conv_body(x)

        b, c, h, w = out.shape
        # concatenate a group stddev statistics to out
        group = min(b, self.stddev_group)  # minibatch must be divisible by (or smaller than) group_size
        stddev = out.view(group, -1, self.stddev_feat, c // self.stddev_feat, h, w)
        stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
        stddev = stddev.repeat(group, 1, h, w)
        out = torch.cat([out, stddev], 1)

        out = self.final_conv(out)
        out = out.view(b, -1)
        out = self.final_linear(out)

        return out
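
Note (illustrative, not part of the patch): the minibatch-stddev step in the forward pass appends one extra feature channel holding the per-group standard deviation of the 4x4 features, letting the discriminator sense sample diversity. A minimal sketch, assuming the class above:

    # Hypothetical usage sketch for StyleGAN2Discriminator.
    disc = StyleGAN2Discriminator(out_size=256)
    logits = disc(torch.randn(4, 3, 256, 256))
    print(logits.shape)                  # torch.Size([4, 1]) -- one realness score per image
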
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/swinir_arch.py
ADDED
@@ -0,0 +1,956 @@
# Modified from https://github.com/JingyunLiang/SwinIR
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
# Originally written by Ze Liu, modified by Jingyun Liang.

import math
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

from r_basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import to_2tuple, trunc_normal_


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).

    From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).

    From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class Mlp(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (b, h, w, c)
        window_size (int): window size

    Returns:
        windows: (num_windows*b, window_size, window_size, c)
    """
    b, h, w, c = x.shape
    x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
    return windows


def window_reverse(windows, window_size, h, w):
    """
    Args:
        windows: (num_windows*b, window_size, window_size, c)
        window_size (int): Window size
        h (int): Height of image
        w (int): Width of image

    Returns:
        x: (b, h, w, c)
    """
    b = int(windows.shape[0] / (h * w / window_size / window_size))
    x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
    return x

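
Note (illustrative, not part of the patch): window_partition and window_reverse are exact inverses whenever h and w are multiples of window_size, which is what lets attention run per window instead of over the full image. A round-trip check, assuming the two helpers above:

    # Hypothetical round-trip check for window_partition / window_reverse.
    x = torch.randn(2, 16, 16, 32)        # (b, h, w, c) with h, w divisible by 8
    wins = window_partition(x, 8)         # 2 * (16/8) * (16/8) = 8 windows
    y = window_reverse(wins, 8, 16, 16)
    print(wins.shape, torch.equal(x, y))  # torch.Size([8, 8, 8, 32]) True
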
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer('relative_position_index', relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)

        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*b, n, c)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        b_, n, c = x.shape
        qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nw = mask.shape[0]
            attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, n, n)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, n):
        # calculate flops for 1 window with token length of n
        flops = 0
        # qkv = self.qkv(x)
        flops += n * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * n * (self.dim // self.num_heads) * n
        # x = (attn @ v)
        flops += self.num_heads * n * n * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += n * self.dim * self.dim
        return flops


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self,
                 dim,
                 input_resolution,
                 num_heads,
                 window_size=7,
                 shift_size=0,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, 'shift_size must be in [0, window_size)'

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            attn_mask = self.calculate_mask(self.input_resolution)
        else:
            attn_mask = None

        self.register_buffer('attn_mask', attn_mask)

    def calculate_mask(self, x_size):
        # calculate attention mask for SW-MSA
        h, w = x_size
        img_mask = torch.zeros((1, h, w, 1))  # 1 h w 1
        h_slices = (slice(0, -self.window_size), slice(-self.window_size,
                                                       -self.shift_size), slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size), slice(-self.window_size,
                                                       -self.shift_size), slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nw, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        return attn_mask

    def forward(self, x, x_size):
        h, w = x_size
        b, _, c = x.shape
        # assert seq_len == h * w, 'input feature has wrong size'

        shortcut = x
        x = self.norm1(x)
        x = x.view(b, h, w, c)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nw*b, window_size, window_size, c
        x_windows = x_windows.view(-1, self.window_size * self.window_size, c)  # nw*b, window_size*window_size, c

        # W-MSA/SW-MSA (the cached mask is only valid when the test size matches the training resolution)
        if self.input_resolution == x_size:
            attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nw*b, window_size*window_size, c
        else:
            attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
        shifted_x = window_reverse(attn_windows, self.window_size, h, w)  # b h' w' c

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(b, h * w, c)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return (f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, '
                f'window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}')

    def flops(self):
        flops = 0
        h, w = self.input_resolution
        # norm1
        flops += self.dim * h * w
        # W-MSA/SW-MSA
        nw = h * w / self.window_size / self.window_size
        flops += nw * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * h * w * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * h * w
        return flops

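
Note (illustrative, not part of the patch): the relative position index maps every (query, key) pair inside a window to one of (2*Wh-1)*(2*Ww-1) learnable bias entries, shared across all windows. A small hand check for a 2x2 window, assuming the class above:

    # Hypothetical check of the relative-position bias bookkeeping.
    attn = WindowAttention(dim=32, window_size=(2, 2), num_heads=4)
    print(attn.relative_position_bias_table.shape)  # torch.Size([9, 4]): 9 relative offsets, 4 heads
    print(attn.relative_position_index.shape)       # torch.Size([4, 4]); values in [0, 8]
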
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: b, h*w, c
        """
        h, w = self.input_resolution
        b, seq_len, c = x.shape
        assert seq_len == h * w, 'input feature has wrong size'
        assert h % 2 == 0 and w % 2 == 0, f'x size ({h}*{w}) is not even.'

        x = x.view(b, h, w, c)

        x0 = x[:, 0::2, 0::2, :]  # b h/2 w/2 c
        x1 = x[:, 1::2, 0::2, :]  # b h/2 w/2 c
        x2 = x[:, 0::2, 1::2, :]  # b h/2 w/2 c
        x3 = x[:, 1::2, 1::2, :]  # b h/2 w/2 c
        x = torch.cat([x0, x1, x2, x3], -1)  # b h/2 w/2 4*c
        x = x.view(b, -1, 4 * c)  # b h/2*w/2 4*c

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f'input_resolution={self.input_resolution}, dim={self.dim}'

    def flops(self):
        h, w = self.input_resolution
        flops = h * w * self.dim
        flops += (h // 2) * (w // 2) * 4 * self.dim * 2 * self.dim
        return flops


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self,
                 dim,
                 input_resolution,
                 depth,
                 num_heads,
                 window_size,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(
                dim=dim,
                input_resolution=input_resolution,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer) for i in range(depth)
        ])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, x_size):
        for blk in self.blocks:
            if self.use_checkpoint:
                # the block's forward also needs x_size when checkpointed
                x = checkpoint.checkpoint(blk, x, x_size)
            else:
                x = blk(x, x_size)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops


class RSTB(nn.Module):
    """Residual Swin Transformer Block (RSTB).

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        img_size: Input image size.
        patch_size: Patch size.
        resi_connection: The convolutional block before the residual connection.
    """

    def __init__(self,
                 dim,
                 input_resolution,
                 depth,
                 num_heads,
                 window_size,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False,
                 img_size=224,
                 patch_size=4,
                 resi_connection='1conv'):
        super(RSTB, self).__init__()

        self.dim = dim
        self.input_resolution = input_resolution

        self.residual_group = BasicLayer(
            dim=dim,
            input_resolution=input_resolution,
            depth=depth,
            num_heads=num_heads,
            window_size=window_size,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=drop,
            attn_drop=attn_drop,
            drop_path=drop_path,
            norm_layer=norm_layer,
            downsample=downsample,
            use_checkpoint=use_checkpoint)

        if resi_connection == '1conv':
            self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory
            self.conv = nn.Sequential(
                nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(dim // 4, dim, 3, 1, 1))

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)

        self.patch_unembed = PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)

    def forward(self, x, x_size):
        return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x

    def flops(self):
        flops = 0
        flops += self.residual_group.flops()
        h, w = self.input_resolution
        flops += h * w * self.dim * self.dim * 9
        flops += self.patch_embed.flops()
        flops += self.patch_unembed.flops()

        return flops


class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)  # b Ph*Pw c
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        flops = 0
        h, w = self.img_size
        if self.norm is not None:
            flops += h * w * self.embed_dim
        return flops


class PatchUnEmbed(nn.Module):
    r""" Image to Patch Unembedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x, x_size):
        x = x.transpose(1, 2).view(x.shape[0], self.embed_dim, x_size[0], x_size[1])  # b c h w
        return x

    def flops(self):
        flops = 0
        return flops


class Upsample(nn.Sequential):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)


class UpsampleOneStep(nn.Sequential):
    """UpsampleOneStep module (unlike Upsample, it always uses a single conv followed by a single pixelshuffle).
    Used in lightweight SR to save parameters.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    """

    def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
        self.num_feat = num_feat
        self.input_resolution = input_resolution
        m = []
        m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
        m.append(nn.PixelShuffle(scale))
        super(UpsampleOneStep, self).__init__(*m)

    def flops(self):
        h, w = self.input_resolution
        flops = h * w * self.num_feat * 3 * 9
        return flops

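
Note (illustrative, not part of the patch): Upsample stacks conv + PixelShuffle pairs (a x4 factor becomes two x2 stages), while UpsampleOneStep covers the whole factor with a single conv, trading capacity for parameters. A quick shape sketch, assuming the two modules above:

    # Hypothetical comparison of the two upsamplers.
    up = Upsample(scale=4, num_feat=64)                             # two conv(64->256) + PixelShuffle(2) stages
    up_light = UpsampleOneStep(scale=4, num_feat=64, num_out_ch=3)  # one conv(64->48) + PixelShuffle(4)
    x = torch.randn(1, 64, 16, 16)
    print(up(x).shape)        # torch.Size([1, 64, 64, 64])
    print(up_light(x).shape)  # torch.Size([1, 3, 64, 64])
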
@ARCH_REGISTRY.register()
class SwinIR(nn.Module):
    r""" SwinIR
    A PyTorch implementation of `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.

    Args:
        img_size (int | tuple(int)): Input image size. Default 64
        patch_size (int | tuple(int)): Patch size. Default: 1
        in_chans (int): Number of input image channels. Default: 3
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
        img_range: Image range. 1. or 255.
        upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
        resi_connection: The convolutional block before the residual connection. '1conv'/'3conv'
    """

    def __init__(self,
                 img_size=64,
                 patch_size=1,
                 in_chans=3,
                 embed_dim=96,
                 depths=(6, 6, 6, 6),
                 num_heads=(6, 6, 6, 6),
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 use_checkpoint=False,
                 upscale=2,
                 img_range=1.,
                 upsampler='',
                 resi_connection='1conv',
                 **kwargs):
        super(SwinIR, self).__init__()
        num_in_ch = in_chans
        num_out_ch = in_chans
        num_feat = 64
        self.img_range = img_range
        if in_chans == 3:
            rgb_mean = (0.4488, 0.4371, 0.4040)
            self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
        else:
            self.mean = torch.zeros(1, 1, 1, 1)
        self.upscale = upscale
        self.upsampler = upsampler

        # ------------------------- 1, shallow feature extraction ------------------------- #
        self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)

        # ------------------------- 2, deep feature extraction ------------------------- #
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=embed_dim,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # merge non-overlapping patches into image
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=embed_dim,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build Residual Swin Transformer blocks (RSTB)
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = RSTB(
                dim=embed_dim,
                input_resolution=(patches_resolution[0], patches_resolution[1]),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
                norm_layer=norm_layer,
                downsample=None,
                use_checkpoint=use_checkpoint,
                img_size=img_size,
                patch_size=patch_size,
                resi_connection=resi_connection)
            self.layers.append(layer)
        self.norm = norm_layer(self.num_features)

        # build the last conv layer in deep feature extraction
        if resi_connection == '1conv':
            self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory
            self.conv_after_body = nn.Sequential(
                nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))

        # ------------------------- 3, high quality image reconstruction ------------------------- #
        if self.upsampler == 'pixelshuffle':
            # for classical SR
            self.conv_before_upsample = nn.Sequential(
                nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
            self.upsample = Upsample(upscale, num_feat)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
        elif self.upsampler == 'pixelshuffledirect':
            # for lightweight SR (to save parameters)
            self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
                                            (patches_resolution[0], patches_resolution[1]))
        elif self.upsampler == 'nearest+conv':
            # for real-world SR (less artifacts)
            assert self.upscale == 4, 'only support x4 now.'
            self.conv_before_upsample = nn.Sequential(
                nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
            self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
            self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
            self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            # for image denoising and JPEG compression artifact reduction
            self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        x_size = (x.shape[2], x.shape[3])
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x, x_size)

        x = self.norm(x)  # b seq_len c
        x = self.patch_unembed(x, x_size)

        return x

    def forward(self, x):
        self.mean = self.mean.type_as(x)
        x = (x - self.mean) * self.img_range

        if self.upsampler == 'pixelshuffle':
            # for classical SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.conv_before_upsample(x)
            x = self.conv_last(self.upsample(x))
        elif self.upsampler == 'pixelshuffledirect':
            # for lightweight SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.upsample(x)
        elif self.upsampler == 'nearest+conv':
            # for real-world SR
            x = self.conv_first(x)
            x = self.conv_after_body(self.forward_features(x)) + x
            x = self.conv_before_upsample(x)
            x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
            x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
            x = self.conv_last(self.lrelu(self.conv_hr(x)))
        else:
            # for image denoising and JPEG compression artifact reduction
            x_first = self.conv_first(x)
            res = self.conv_after_body(self.forward_features(x_first)) + x_first
            x = x + self.conv_last(res)

        x = x / self.img_range + self.mean

        return x

    def flops(self):
        flops = 0
        h, w = self.patches_resolution
        flops += h * w * 3 * self.embed_dim * 9
        flops += self.patch_embed.flops()
        for layer in self.layers:
            flops += layer.flops()
        flops += h * w * 3 * self.embed_dim * self.embed_dim
        flops += self.upsample.flops()
        return flops


if __name__ == '__main__':
    upscale = 4
    window_size = 8
    height = (1024 // upscale // window_size + 1) * window_size
    width = (720 // upscale // window_size + 1) * window_size
    model = SwinIR(
        upscale=2,
        img_size=(height, width),
        window_size=window_size,
        img_range=1.,
        depths=[6, 6, 6, 6],
        embed_dim=60,
        num_heads=[6, 6, 6, 6],
        mlp_ratio=2,
        upsampler='pixelshuffledirect')
    print(model)
    print(height, width, model.flops() / 1e9)

    x = torch.randn((1, 3, height, width))
    x = model(x)
    print(x.shape)
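
Note (illustrative, not part of the patch): the height/width arithmetic in the __main__ block above simply pads the nominal LR input up to the next multiple of window_size, since window partitioning requires h and w divisible by the window size:

    # Hypothetical restatement of the padding arithmetic.
    upscale, window_size = 4, 8
    height = (1024 // upscale // window_size + 1) * window_size  # 256 -> 264
    width = (720 // upscale // window_size + 1) * window_size    # 180 -> 184
    assert height % window_size == 0 and width % window_size == 0
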
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/tof_arch.py
ADDED
@@ -0,0 +1,172 @@
import torch
from torch import nn as nn
from torch.nn import functional as F

from r_basicsr.utils.registry import ARCH_REGISTRY
from .arch_util import flow_warp


class BasicModule(nn.Module):
    """Basic module of SPyNet.

    Note that unlike the architecture in spynet_arch.py, the basic module
    here contains batch normalization.
    """

    def __init__(self):
        super(BasicModule, self).__init__()
        self.basic_module = nn.Sequential(
            nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False),
            nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3, bias=False),
            nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False),
            nn.BatchNorm2d(32), nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3, bias=False),
            nn.BatchNorm2d(16), nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))

    def forward(self, tensor_input):
        """
        Args:
            tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
                The 8 channels contain:
                [reference image (3), neighbor image (3), initial flow (2)].

        Returns:
            Tensor: Estimated flow with shape (b, 2, h, w).
        """
        return self.basic_module(tensor_input)


class SPyNetTOF(nn.Module):
    """SPyNet architecture for TOF.

    Note that this implementation is specifically for TOFlow. Please use
    spynet_arch.py for general use. They differ in the following aspects:
    1. The basic modules here contain BatchNorm.
    2. Normalization and denormalization are not done here, as
       they are done in TOFlow.
    Paper:
        Optical Flow Estimation using a Spatial Pyramid Network
    Code reference:
        https://github.com/Coldog2333/pytoflow

    Args:
        load_path (str): Path for pretrained SPyNet. Default: None.
    """

    def __init__(self, load_path=None):
        super(SPyNetTOF, self).__init__()

        self.basic_module = nn.ModuleList([BasicModule() for _ in range(4)])
        if load_path:
            self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])

    def forward(self, ref, supp):
        """
        Args:
            ref (Tensor): Reference image with shape of (b, 3, h, w).
            supp: The supporting image to be warped: (b, 3, h, w).

        Returns:
            Tensor: Estimated optical flow: (b, 2, h, w).
        """
        num_batches, _, h, w = ref.size()
        ref = [ref]
        supp = [supp]

        # generate downsampled frames
        for _ in range(3):
            ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False))
            supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False))

        # flow computation
        flow = ref[0].new_zeros(num_batches, 2, h // 16, w // 16)
        for i in range(4):
            flow_up = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
            flow = flow_up + self.basic_module[i](
                torch.cat([ref[i], flow_warp(supp[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1))
        return flow

@ARCH_REGISTRY.register()
|
94 |
+
class TOFlow(nn.Module):
|
95 |
+
"""PyTorch implementation of TOFlow.
|
96 |
+
|
97 |
+
In TOFlow, the LR frames are pre-upsampled and have the same size with
|
98 |
+
the GT frames.
|
99 |
+
Paper:
|
100 |
+
Xue et al., Video Enhancement with Task-Oriented Flow, IJCV 2018
|
101 |
+
Code reference:
|
102 |
+
1. https://github.com/anchen1011/toflow
|
103 |
+
2. https://github.com/Coldog2333/pytoflow
|
104 |
+
|
105 |
+
Args:
|
106 |
+
adapt_official_weights (bool): Whether to adapt the weights translated
|
107 |
+
from the official implementation. Set to false if you want to
|
108 |
+
train from scratch. Default: False
|
109 |
+
"""
|
110 |
+
|
111 |
+
def __init__(self, adapt_official_weights=False):
|
112 |
+
super(TOFlow, self).__init__()
|
113 |
+
self.adapt_official_weights = adapt_official_weights
|
114 |
+
self.ref_idx = 0 if adapt_official_weights else 3
|
115 |
+
|
116 |
+
self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
117 |
+
self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
118 |
+
|
119 |
+
# flow estimation module
|
120 |
+
self.spynet = SPyNetTOF()
|
121 |
+
|
122 |
+
# reconstruction module
|
123 |
+
self.conv_1 = nn.Conv2d(3 * 7, 64, 9, 1, 4)
|
124 |
+
self.conv_2 = nn.Conv2d(64, 64, 9, 1, 4)
|
125 |
+
self.conv_3 = nn.Conv2d(64, 64, 1)
|
126 |
+
self.conv_4 = nn.Conv2d(64, 3, 1)
|
127 |
+
|
128 |
+
# activation function
|
129 |
+
self.relu = nn.ReLU(inplace=True)
|
130 |
+
|
131 |
+
def normalize(self, img):
|
132 |
+
return (img - self.mean) / self.std
|
133 |
+
|
134 |
+
def denormalize(self, img):
|
135 |
+
return img * self.std + self.mean
|
136 |
+
|
137 |
+
def forward(self, lrs):
|
138 |
+
"""
|
139 |
+
Args:
|
140 |
+
lrs: Input lr frames: (b, 7, 3, h, w).
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
Tensor: SR frame: (b, 3, h, w).
|
144 |
+
"""
|
145 |
+
# In the official implementation, the 0-th frame is the reference frame
|
146 |
+
if self.adapt_official_weights:
|
147 |
+
lrs = lrs[:, [3, 0, 1, 2, 4, 5, 6], :, :, :]
|
148 |
+
|
149 |
+
num_batches, num_lrs, _, h, w = lrs.size()
|
150 |
+
|
151 |
+
lrs = self.normalize(lrs.view(-1, 3, h, w))
|
152 |
+
lrs = lrs.view(num_batches, num_lrs, 3, h, w)
|
153 |
+
|
154 |
+
lr_ref = lrs[:, self.ref_idx, :, :, :]
|
155 |
+
lr_aligned = []
|
156 |
+
for i in range(7): # 7 frames
|
157 |
+
if i == self.ref_idx:
|
158 |
+
lr_aligned.append(lr_ref)
|
159 |
+
else:
|
160 |
+
lr_supp = lrs[:, i, :, :, :]
|
161 |
+
flow = self.spynet(lr_ref, lr_supp)
|
162 |
+
lr_aligned.append(flow_warp(lr_supp, flow.permute(0, 2, 3, 1)))
|
163 |
+
|
164 |
+
# reconstruction
|
165 |
+
hr = torch.stack(lr_aligned, dim=1)
|
166 |
+
hr = hr.view(num_batches, -1, h, w)
|
167 |
+
hr = self.relu(self.conv_1(hr))
|
168 |
+
hr = self.relu(self.conv_2(hr))
|
169 |
+
hr = self.relu(self.conv_3(hr))
|
170 |
+
hr = self.conv_4(hr) + lr_ref
|
171 |
+
|
172 |
+
return self.denormalize(hr)
|
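A minimal usage sketch for the TOFlow class above, assuming random tensors stand in for a real stack of 7 pre-upsampled LR frames; H and W are kept divisible by 16 so the SPyNetTOF flow pyramid lines up:

import torch

model = TOFlow(adapt_official_weights=False).eval()
lrs = torch.rand(1, 7, 3, 64, 64)  # (b, 7 frames, c, h, w), values in [0, 1]
with torch.no_grad():
    sr = model(lrs)
print(sr.shape)  # torch.Size([1, 3, 64, 64])
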
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/vgg_arch.py
ADDED
@@ -0,0 +1,161 @@
import os
import torch
from collections import OrderedDict
from torch import nn as nn
from torchvision.models import vgg as vgg

from r_basicsr.utils.registry import ARCH_REGISTRY

VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
NAMES = {
    'vgg11': [
        'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
        'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
        'pool5'
    ],
    'vgg13': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
    ],
    'vgg16': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
        'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
        'pool5'
    ],
    'vgg19': [
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
        'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
        'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
    ]
}


def insert_bn(names):
    """Insert bn layer after each conv.

    Args:
        names (list): The list of layer names.

    Returns:
        list: The list of layer names with bn layers.
    """
    names_bn = []
    for name in names:
        names_bn.append(name)
        if 'conv' in name:
            position = name.replace('conv', '')
            names_bn.append('bn' + position)
    return names_bn


@ARCH_REGISTRY.register()
class VGGFeatureExtractor(nn.Module):
    """VGG network for feature extraction.

    In this implementation, we allow users to choose whether to use
    normalization on the input feature and which type of vgg network to use.
    Note that the pretrained path must fit the vgg type.

    Args:
        layer_name_list (list[str]): Forward function returns the corresponding
            features according to the layer_name_list.
            Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
        vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
        use_input_norm (bool): If True, normalize the input image. Importantly,
            the input feature must be in the range [0, 1]. Default: True.
        range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
            Default: False.
        requires_grad (bool): If True, the parameters of the VGG network will
            be optimized. Default: False.
        remove_pooling (bool): If True, the max pooling operations in the VGG
            net will be removed. Default: False.
        pooling_stride (int): The stride of max pooling operation. Default: 2.
    """

    def __init__(self,
                 layer_name_list,
                 vgg_type='vgg19',
                 use_input_norm=True,
                 range_norm=False,
                 requires_grad=False,
                 remove_pooling=False,
                 pooling_stride=2):
        super(VGGFeatureExtractor, self).__init__()

        self.layer_name_list = layer_name_list
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        self.names = NAMES[vgg_type.replace('_bn', '')]
        if 'bn' in vgg_type:
            self.names = insert_bn(self.names)

        # only borrow layers that will be used to avoid unused params
        max_idx = 0
        for v in layer_name_list:
            idx = self.names.index(v)
            if idx > max_idx:
                max_idx = idx

        if os.path.exists(VGG_PRETRAIN_PATH):
            vgg_net = getattr(vgg, vgg_type)(pretrained=False)
            state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
            vgg_net.load_state_dict(state_dict)
        else:
            vgg_net = getattr(vgg, vgg_type)(pretrained=True)

        features = vgg_net.features[:max_idx + 1]

        modified_net = OrderedDict()
        for k, v in zip(self.names, features):
            if 'pool' in k:
                # if remove_pooling is true, pooling operation will be removed
                if remove_pooling:
                    continue
                else:
                    # in some cases, we may want to change the default stride
                    modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
            else:
                modified_net[k] = v

        self.vgg_net = nn.Sequential(modified_net)

        if not requires_grad:
            self.vgg_net.eval()
            for param in self.parameters():
                param.requires_grad = False
        else:
            self.vgg_net.train()
            for param in self.parameters():
                param.requires_grad = True

        if self.use_input_norm:
            # the mean is for image with range [0, 1]
            self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            # the std is for image with range [0, 1]
            self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input tensor with shape (n, c, h, w).

        Returns:
            Tensor: Forward results.
        """
        if self.range_norm:
            x = (x + 1) / 2
        if self.use_input_norm:
            x = (x - self.mean) / self.std

        output = {}
        for key, layer in self.vgg_net._modules.items():
            x = layer(x)
            if key in self.layer_name_list:
                output[key] = x.clone()

        return output
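A minimal sketch of extracting perceptual features with VGGFeatureExtractor, assuming the torchvision vgg19 weights can be downloaded (or already exist at VGG_PRETRAIN_PATH) and that a random tensor stands in for an image batch in [0, 1]:

import torch

extractor = VGGFeatureExtractor(layer_name_list=['relu1_1', 'relu2_1', 'relu3_1'])
with torch.no_grad():
    feats = extractor(torch.rand(1, 3, 224, 224))
for name, feat in feats.items():
    print(name, tuple(feat.shape))  # relu1_1 -> (1, 64, 224, 224), etc.
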
custom_nodes/ComfyUI-ReActor/r_basicsr/data/__init__.py
ADDED
@@ -0,0 +1,101 @@
import importlib
import numpy as np
import random
import torch
import torch.utils.data
from copy import deepcopy
from functools import partial
from os import path as osp

from r_basicsr.data.prefetch_dataloader import PrefetchDataLoader
from r_basicsr.utils import get_root_logger, scandir
from r_basicsr.utils.dist_util import get_dist_info
from r_basicsr.utils.registry import DATASET_REGISTRY

__all__ = ['build_dataset', 'build_dataloader']

# automatically scan and import dataset modules for registry
# scan all the files under the data folder with '_dataset' in file names
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules
_dataset_modules = [importlib.import_module(f'r_basicsr.data.{file_name}') for file_name in dataset_filenames]


def build_dataset(dataset_opt):
    """Build dataset from options.

    Args:
        dataset_opt (dict): Configuration for dataset. It must contain:
            name (str): Dataset name.
            type (str): Dataset type.
    """
    dataset_opt = deepcopy(dataset_opt)
    dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
    logger = get_root_logger()
    logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.')
    return dataset


def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
    """Build dataloader.

    Args:
        dataset (torch.utils.data.Dataset): Dataset.
        dataset_opt (dict): Dataset options. It contains the following keys:
            phase (str): 'train' or 'val'.
            num_worker_per_gpu (int): Number of workers for each GPU.
            batch_size_per_gpu (int): Training batch size for each GPU.
        num_gpu (int): Number of GPUs. Used only in the train phase.
            Default: 1.
        dist (bool): Whether in distributed training. Used only in the train
            phase. Default: False.
        sampler (torch.utils.data.sampler): Data sampler. Default: None.
        seed (int | None): Seed. Default: None
    """
    phase = dataset_opt['phase']
    rank, _ = get_dist_info()
    if phase == 'train':
        if dist:  # distributed training
            batch_size = dataset_opt['batch_size_per_gpu']
            num_workers = dataset_opt['num_worker_per_gpu']
        else:  # non-distributed training
            multiplier = 1 if num_gpu == 0 else num_gpu
            batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
            num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
        dataloader_args = dict(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            sampler=sampler,
            drop_last=True)
        if sampler is None:
            dataloader_args['shuffle'] = True
        dataloader_args['worker_init_fn'] = partial(
            worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
    elif phase in ['val', 'test']:  # validation
        dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
    else:
        raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")

    dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
    dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)

    prefetch_mode = dataset_opt.get('prefetch_mode')
    if prefetch_mode == 'cpu':  # CPUPrefetcher
        num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
        logger = get_root_logger()
        logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
        return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
    else:
        # prefetch_mode=None: Normal dataloader
        # prefetch_mode='cuda': dataloader for CUDAPrefetcher
        return torch.utils.data.DataLoader(**dataloader_args)


def worker_init_fn(worker_id, num_workers, rank, seed):
    # Set the worker seed to num_workers * rank + worker_id + seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
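A minimal sketch of the options dict the two builders above expect; the dataset type 'PairedImageDataset' and all paths and values below are illustrative assumptions, not settings shipped in this diff:

dataset_opt = {
    'name': 'DemoTrain',
    'type': 'PairedImageDataset',  # must be registered in DATASET_REGISTRY
    'phase': 'train',
    'dataroot_gt': 'datasets/demo/gt',
    'dataroot_lq': 'datasets/demo/lq',
    'io_backend': {'type': 'disk'},
    'scale': 4,
    'gt_size': 128,
    'use_hflip': True,
    'use_rot': True,
    'batch_size_per_gpu': 4,
    'num_worker_per_gpu': 2,
}
train_set = build_dataset(dataset_opt)
train_loader = build_dataloader(train_set, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=0)
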
custom_nodes/ComfyUI-ReActor/r_basicsr/data/data_sampler.py
ADDED
@@ -0,0 +1,48 @@
import math
import torch
from torch.utils.data.sampler import Sampler


class EnlargedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    Modified from torch.utils.data.distributed.DistributedSampler
    Support enlarging the dataset for iteration-based training, saving
    time when restarting the dataloader after each epoch.

    Args:
        dataset (torch.utils.data.Dataset): Dataset used for sampling.
        num_replicas (int | None): Number of processes participating in
            the training. It is usually the world_size.
        rank (int | None): Rank of the current process within num_replicas.
        ratio (int): Enlarging ratio. Default: 1.
    """

    def __init__(self, dataset, num_replicas, rank, ratio=1):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = torch.randperm(self.total_size, generator=g).tolist()

        dataset_size = len(self.dataset)
        indices = [v % dataset_size for v in indices]

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
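A minimal sketch of pairing EnlargedSampler with build_dataloader from data/__init__.py, assuming a single process (num_replicas=1, rank=0) and reusing the hypothetical train_set/dataset_opt names from the previous sketch; ratio=100 virtually repeats the dataset so one loader "epoch" covers many passes without restarting workers:

sampler = EnlargedSampler(train_set, num_replicas=1, rank=0, ratio=100)
train_loader = build_dataloader(train_set, dataset_opt, num_gpu=1, dist=False, sampler=sampler, seed=0)
for epoch in range(2):
    sampler.set_epoch(epoch)  # re-seeds the deterministic shuffle each epoch
    for batch in train_loader:
        pass  # training step goes here
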
custom_nodes/ComfyUI-ReActor/r_basicsr/data/data_util.py
ADDED
@@ -0,0 +1,313 @@
import cv2
import numpy as np
import torch
from os import path as osp
from torch.nn import functional as F

from r_basicsr.data.transforms import mod_crop
from r_basicsr.utils import img2tensor, scandir


def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
    """Read a sequence of images from a given folder path.

    Args:
        path (list[str] | str): List of image paths or image folder path.
        require_mod_crop (bool): Require mod crop for each image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.
        return_imgname (bool): Whether to return image names. Default: False.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].
        list[str]: Returned image name list.
    """
    if isinstance(path, list):
        img_paths = path
    else:
        img_paths = sorted(list(scandir(path, full_path=True)))
    imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]

    if require_mod_crop:
        imgs = [mod_crop(img, scale) for img in imgs]
    imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
    imgs = torch.stack(imgs, dim=0)

    if return_imgname:
        imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
        return imgs, imgnames
    else:
        return imgs


def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
    """Generate an index list for reading `num_frames` frames from a sequence
    of images.

    Args:
        crt_idx (int): Current center index.
        max_frame_num (int): Max number of the sequence of images (from 1).
        num_frames (int): Reading num_frames frames.
        padding (str): Padding mode, one of
            'replicate' | 'reflection' | 'reflection_circle' | 'circle'
            Examples: current_idx = 0, num_frames = 5
            The generated frame indices under different padding mode:
            replicate: [0, 0, 0, 1, 2]
            reflection: [2, 1, 0, 1, 2]
            reflection_circle: [4, 3, 0, 1, 2]
            circle: [3, 4, 0, 1, 2]

    Returns:
        list[int]: A list of indices.
    """
    assert num_frames % 2 == 1, 'num_frames should be an odd number.'
    assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'

    max_frame_num = max_frame_num - 1  # start from 0
    num_pad = num_frames // 2

    indices = []
    for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
        if i < 0:
            if padding == 'replicate':
                pad_idx = 0
            elif padding == 'reflection':
                pad_idx = -i
            elif padding == 'reflection_circle':
                pad_idx = crt_idx + num_pad - i
            else:
                pad_idx = num_frames + i
        elif i > max_frame_num:
            if padding == 'replicate':
                pad_idx = max_frame_num
            elif padding == 'reflection':
                pad_idx = max_frame_num * 2 - i
            elif padding == 'reflection_circle':
                pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
            else:
                pad_idx = i - num_frames
        else:
            pad_idx = i
        indices.append(pad_idx)
    return indices


def paired_paths_from_lmdb(folders, keys):
    """Generate paired paths from lmdb files.

    Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:

    lq.lmdb
    ├── data.mdb
    ├── lock.mdb
    ├── meta_info.txt

    The data.mdb and lock.mdb are standard lmdb files and you can refer to
    https://lmdb.readthedocs.io/en/release/ for more details.

    The meta_info.txt is a specified txt file to record the meta information
    of our datasets. It will be automatically created when preparing
    datasets by our provided dataset tools.
    Each line in the txt file records
    1)image name (with extension),
    2)image shape,
    3)compression level, separated by a white space.
    Example: `baboon.png (120,125,3) 1`

    We use the image name without extension as the lmdb key.
    Note that we use the same key for the corresponding lq and gt images.

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt'].
            Note that this key is different from lmdb keys.

    Returns:
        list[str]: Returned path list.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
        raise ValueError(f'{input_key} folder and {gt_key} folder should both be in lmdb '
                         f'format. But received {input_key}: {input_folder}; '
                         f'{gt_key}: {gt_folder}')
    # ensure that the two meta_info files are the same
    with open(osp.join(input_folder, 'meta_info.txt')) as fin:
        input_lmdb_keys = [line.split('.')[0] for line in fin]
    with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
        gt_lmdb_keys = [line.split('.')[0] for line in fin]
    if set(input_lmdb_keys) != set(gt_lmdb_keys):
        raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
    else:
        paths = []
        for lmdb_key in sorted(input_lmdb_keys):
            paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
        return paths


def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
    """Generate paired paths from a meta information file.

    Each line in the meta information file contains the image names and
    image shape (usually for gt), separated by a white space.

    Example of a meta information file:
    ```
    0001_s001.png (480,480,3)
    0001_s002.png (480,480,3)
    ```

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt'].
        meta_info_file (str): Path to the meta information file.
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[str]: Returned path list.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    with open(meta_info_file, 'r') as fin:
        gt_names = [line.strip().split(' ')[0] for line in fin]

    paths = []
    for gt_name in gt_names:
        basename, ext = osp.splitext(osp.basename(gt_name))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        input_path = osp.join(input_folder, input_name)
        gt_path = osp.join(gt_folder, gt_name)
        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
    return paths


def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired paths from folders.

    Args:
        folders (list[str]): A list of folder paths. The order of the list
            should be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be consistent with folders, e.g., ['lq', 'gt'].
        filename_tmpl (str): Template for each filename. Note that the
            template excludes the file extension. Usually the filename_tmpl is
            for files in the input folder.

    Returns:
        list[str]: Returned path list.
    """
    assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
                               f'But got {len(folders)}')
    assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    input_paths = list(scandir(input_folder))
    gt_paths = list(scandir(gt_folder))
    assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
                                               f'{len(input_paths)}, {len(gt_paths)}.')
    paths = []
    for gt_path in gt_paths:
        basename, ext = osp.splitext(osp.basename(gt_path))
        input_name = f'{filename_tmpl.format(basename)}{ext}'
        input_path = osp.join(input_folder, input_name)
        assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.'
        gt_path = osp.join(gt_folder, gt_path)
        paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
    return paths


def paths_from_folder(folder):
    """Generate paths from folder.

    Args:
        folder (str): Folder path.

    Returns:
        list[str]: Returned path list.
    """

    paths = list(scandir(folder))
    paths = [osp.join(folder, path) for path in paths]
    return paths


def paths_from_lmdb(folder):
    """Generate paths from lmdb.

    Args:
        folder (str): Folder path.

    Returns:
        list[str]: Returned path list.
    """
    if not folder.endswith('.lmdb'):
        raise ValueError(f'Folder {folder} should be in lmdb format.')
    with open(osp.join(folder, 'meta_info.txt')) as fin:
        paths = [line.split('.')[0] for line in fin]
    return paths


def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
    """Generate Gaussian kernel used in `duf_downsample`.

    Args:
        kernel_size (int): Kernel size. Default: 13.
        sigma (float): Sigma of the Gaussian kernel. Default: 1.6.

    Returns:
        np.array: The Gaussian kernel.
    """
    # scipy.ndimage.filters was deprecated and later removed; import directly
    from scipy.ndimage import gaussian_filter
    kernel = np.zeros((kernel_size, kernel_size))
    # set element at the middle to one, a dirac delta
    kernel[kernel_size // 2, kernel_size // 2] = 1
    # gaussian-smooth the dirac, resulting in a gaussian filter
    return gaussian_filter(kernel, sigma)


def duf_downsample(x, kernel_size=13, scale=4):
    """Downsampling with Gaussian kernel used in the DUF official code.

    Args:
        x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
        kernel_size (int): Kernel size. Default: 13.
        scale (int): Downsampling factor. Supported scale: (2, 3, 4).
            Default: 4.

    Returns:
        Tensor: DUF downsampled frames.
    """
    assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'

    squeeze_flag = False
    if x.ndim == 4:
        squeeze_flag = True
        x = x.unsqueeze(0)
    b, t, c, h, w = x.size()
    x = x.view(-1, 1, h, w)
    pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
    x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')

    gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
    gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
    x = F.conv2d(x, gaussian_filter, stride=scale)
    x = x[:, :, 2:-2, 2:-2]
    x = x.view(b, t, c, x.size(2), x.size(3))
    if squeeze_flag:
        x = x.squeeze(0)
    return x
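A minimal sketch reproducing the generate_frame_indices docstring example, showing how the four padding modes behave at the start of a sequence:

for mode in ('replicate', 'reflection', 'reflection_circle', 'circle'):
    print(mode, generate_frame_indices(0, max_frame_num=100, num_frames=5, padding=mode))
# replicate          [0, 0, 0, 1, 2]
# reflection         [2, 1, 0, 1, 2]
# reflection_circle  [4, 3, 0, 1, 2]
# circle             [3, 4, 0, 1, 2]
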
custom_nodes/ComfyUI-ReActor/r_basicsr/data/degradations.py
ADDED
@@ -0,0 +1,768 @@
import cv2
import math
import numpy as np
import random
import torch
from scipy import special
from scipy.stats import multivariate_normal
try:
    from torchvision.transforms.functional_tensor import rgb_to_grayscale
except ImportError:  # the private functional_tensor module was removed in newer torchvision
    from torchvision.transforms.functional import rgb_to_grayscale

# -------------------------------------------------------------------- #
# --------------------------- blur kernels --------------------------- #
# -------------------------------------------------------------------- #


# --------------------------- util functions --------------------------- #
def sigma_matrix2(sig_x, sig_y, theta):
    """Calculate the rotated sigma matrix (two dimensional matrix).

    Args:
        sig_x (float):
        sig_y (float):
        theta (float): Radian measurement.

    Returns:
        ndarray: Rotated sigma matrix.
    """
    d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
    u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
    return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))


def mesh_grid(kernel_size):
    """Generate the mesh grid, centering at zero.

    Args:
        kernel_size (int):

    Returns:
        xy (ndarray): with the shape (kernel_size, kernel_size, 2)
        xx (ndarray): with the shape (kernel_size, kernel_size)
        yy (ndarray): with the shape (kernel_size, kernel_size)
    """
    ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
    xx, yy = np.meshgrid(ax, ax)
    xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,
                                                                           1))).reshape(kernel_size, kernel_size, 2)
    return xy, xx, yy


def pdf2(sigma_matrix, grid):
    """Calculate PDF of the bivariate Gaussian distribution.

    Args:
        sigma_matrix (ndarray): with the shape (2, 2)
        grid (ndarray): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size.

    Returns:
        kernel (ndarray): un-normalized kernel.
    """
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
    return kernel


def cdf2(d_matrix, grid):
    """Calculate the CDF of the standard bivariate Gaussian distribution.
    Used in skewed Gaussian distribution.

    Args:
        d_matrix (ndarray): skew matrix.
        grid (ndarray): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size.

    Returns:
        cdf (ndarray): skewed cdf.
    """
    rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
    grid = np.dot(grid, d_matrix)
    cdf = rv.cdf(grid)
    return cdf


def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
    """Generate a bivariate isotropic or anisotropic Gaussian kernel.

    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.

    Args:
        kernel_size (int):
        sig_x (float):
        sig_y (float):
        theta (float): Radian measurement.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
        isotropic (bool):

    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    kernel = pdf2(sigma_matrix, grid)
    kernel = kernel / np.sum(kernel)
    return kernel


def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
    """Generate a bivariate generalized Gaussian kernel.
    Described in `Parameter Estimation For Multivariate Generalized
    Gaussian Distributions`_
    by Pascal et al. (2013).

    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.

    Args:
        kernel_size (int):
        sig_x (float):
        sig_y (float):
        theta (float): Radian measurement.
        beta (float): shape parameter, beta = 1 is the normal distribution.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
        isotropic (bool):

    Returns:
        kernel (ndarray): normalized kernel.

    .. _Parameter Estimation For Multivariate Generalized Gaussian
        Distributions: https://arxiv.org/abs/1302.6498
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
    kernel = kernel / np.sum(kernel)
    return kernel


def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
    """Generate a plateau-like anisotropic kernel.
    1 / (1+x^(beta))

    Ref: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution

    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.

    Args:
        kernel_size (int):
        sig_x (float):
        sig_y (float):
        theta (float): Radian measurement.
        beta (float): shape parameter, beta = 1 is the normal distribution.
        grid (ndarray, optional): generated by :func:`mesh_grid`,
            with the shape (K, K, 2), K is the kernel size. Default: None
        isotropic (bool):

    Returns:
        kernel (ndarray): normalized kernel.
    """
    if grid is None:
        grid, _, _ = mesh_grid(kernel_size)
    if isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
    else:
        sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
    inverse_sigma = np.linalg.inv(sigma_matrix)
    kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
    kernel = kernel / np.sum(kernel)
    return kernel


def random_bivariate_Gaussian(kernel_size,
                              sigma_x_range,
                              sigma_y_range,
                              rotation_range,
                              noise_range=None,
                              isotropic=True):
    """Randomly generate bivariate isotropic or anisotropic Gaussian kernels.

    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.

    Args:
        kernel_size (int):
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
        rotation_range (tuple): [-math.pi, math.pi]
        noise_range (tuple, optional): multiplicative kernel noise,
            [0.75, 1.25]. Default: None

    Returns:
        kernel (ndarray):
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0

    kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    return kernel


def random_bivariate_generalized_Gaussian(kernel_size,
                                          sigma_x_range,
                                          sigma_y_range,
                                          rotation_range,
                                          beta_range,
                                          noise_range=None,
                                          isotropic=True):
    """Randomly generate bivariate generalized Gaussian kernels.

    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.

    Args:
        kernel_size (int):
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
        rotation_range (tuple): [-math.pi, math.pi]
        beta_range (tuple): [0.5, 8]
        noise_range (tuple, optional): multiplicative kernel noise,
            [0.75, 1.25]. Default: None

    Returns:
        kernel (ndarray):
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0

    # assume beta_range[0] < 1 < beta_range[1]
    if np.random.uniform() < 0.5:
        beta = np.random.uniform(beta_range[0], 1)
    else:
        beta = np.random.uniform(1, beta_range[1])

    kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)

    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)
    return kernel


def random_bivariate_plateau(kernel_size,
                             sigma_x_range,
                             sigma_y_range,
                             rotation_range,
                             beta_range,
                             noise_range=None,
                             isotropic=True):
    """Randomly generate bivariate plateau kernels.

    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` are ignored.

    Args:
        kernel_size (int):
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
        rotation_range (tuple): [-math.pi/2, math.pi/2]
        beta_range (tuple): [1, 4]
        noise_range (tuple, optional): multiplicative kernel noise,
            [0.75, 1.25]. Default: None

    Returns:
        kernel (ndarray):
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
    if isotropic is False:
        assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
        sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = np.random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y = sigma_x
        rotation = 0

    # TODO: this may not be proper
    if np.random.uniform() < 0.5:
        beta = np.random.uniform(beta_range[0], 1)
    else:
        beta = np.random.uniform(1, beta_range[1])

    kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
    # add multiplicative noise
    if noise_range is not None:
        assert noise_range[0] < noise_range[1], 'Wrong noise range.'
        noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
        kernel = kernel * noise
    kernel = kernel / np.sum(kernel)

    return kernel


def random_mixed_kernels(kernel_list,
                         kernel_prob,
                         kernel_size=21,
                         sigma_x_range=(0.6, 5),
                         sigma_y_range=(0.6, 5),
                         rotation_range=(-math.pi, math.pi),
                         betag_range=(0.5, 8),
                         betap_range=(0.5, 8),
                         noise_range=None):
    """Randomly generate mixed kernels.

    Args:
        kernel_list (tuple): a list of kernel type names, supporting
            ['iso', 'aniso', 'generalized_iso', 'generalized_aniso',
            'plateau_iso', 'plateau_aniso']
        kernel_prob (tuple): corresponding kernel probability for each
            kernel type
        kernel_size (int):
        sigma_x_range (tuple): [0.6, 5]
        sigma_y_range (tuple): [0.6, 5]
        rotation_range (tuple): [-math.pi, math.pi]
        betag_range (tuple): [0.5, 8]
        betap_range (tuple): [0.5, 8]
        noise_range (tuple, optional): multiplicative kernel noise,
            [0.75, 1.25]. Default: None

    Returns:
        kernel (ndarray):
    """
    kernel_type = random.choices(kernel_list, kernel_prob)[0]
    if kernel_type == 'iso':
        kernel = random_bivariate_Gaussian(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True)
    elif kernel_type == 'aniso':
        kernel = random_bivariate_Gaussian(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False)
    elif kernel_type == 'generalized_iso':
        kernel = random_bivariate_generalized_Gaussian(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            betag_range,
            noise_range=noise_range,
            isotropic=True)
    elif kernel_type == 'generalized_aniso':
        kernel = random_bivariate_generalized_Gaussian(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            betag_range,
            noise_range=noise_range,
            isotropic=False)
    elif kernel_type == 'plateau_iso':
        kernel = random_bivariate_plateau(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True)
    elif kernel_type == 'plateau_aniso':
        kernel = random_bivariate_plateau(
            kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False)
    return kernel


np.seterr(divide='ignore', invalid='ignore')


def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
    """2D sinc filter, ref: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter

    Args:
        cutoff (float): cutoff frequency in radians (pi is max)
        kernel_size (int): horizontal and vertical size, must be odd.
        pad_to (int): pad kernel size to desired size, must be odd or zero.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    kernel = np.fromfunction(
        lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
            (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt(
                (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size])
    kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi)
    kernel = kernel / np.sum(kernel)
    if pad_to > kernel_size:
        pad_size = (pad_to - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
    return kernel
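

# Illustrative sketch (not part of the upstream file): how the kernel helpers
# above are typically consumed. Every parameter value here is an assumption
# chosen for demonstration only.
def _demo_random_kernel():
    kernel = random_mixed_kernels(
        kernel_list=['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
        kernel_prob=[0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
        kernel_size=21,
        sigma_x_range=(0.2, 3),
        sigma_y_range=(0.2, 3),
        rotation_range=(-math.pi, math.pi),
        betag_range=(0.5, 4),
        betap_range=(1, 2))
    assert kernel.shape == (21, 21)  # every branch returns a normalized 2D kernel
    return kernel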
|

# ------------------------------------------------------------- #
# --------------------------- noise --------------------------- #
# ------------------------------------------------------------- #

# ----------------------- Gaussian Noise ----------------------- #


def generate_gaussian_noise(img, sigma=10, gray_noise=False):
    """Generate Gaussian noise.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        sigma (float): Noise scale (measured in range 255). Default: 10.
        gray_noise (bool): Whether to generate gray (channel-shared) noise. Default: False.

    Returns:
        (Numpy array): Generated noise, shape (h, w, c), float32.
    """
    if gray_noise:
        noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
        noise = np.expand_dims(noise, axis=2).repeat(3, axis=2)
    else:
        noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.
    return noise


def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False):
    """Add Gaussian noise.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        sigma (float): Noise scale (measured in range 255). Default: 10.

    Returns:
        (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
            float32.
    """
    noise = generate_gaussian_noise(img, sigma, gray_noise)
    out = img + noise
    if clip and rounds:
        out = np.clip((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = np.clip(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
    """Generate Gaussian noise (PyTorch version).

    Args:
        img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
        sigma (float | Tensor): Noise scale (measured in range 255). Default: 10.
        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b). Default: 0.

    Returns:
        (Tensor): Generated noise, shape (b, c, h, w), float32.
    """
    b, _, h, w = img.size()
    if not isinstance(sigma, (float, int)):
        sigma = sigma.view(img.size(0), 1, 1, 1)
    if isinstance(gray_noise, (float, int)):
        cal_gray_noise = gray_noise > 0
    else:
        gray_noise = gray_noise.view(b, 1, 1, 1)
        cal_gray_noise = torch.sum(gray_noise) > 0

    if cal_gray_noise:
        noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255.
        noise_gray = noise_gray.view(b, 1, h, w)

    # always calculate color noise
    noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.

    if cal_gray_noise:
        noise = noise * (1 - gray_noise) + noise_gray * gray_noise
    return noise


def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
    """Add Gaussian noise (PyTorch version).

    Args:
        img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
        sigma (float | Tensor): Noise scale (measured in range 255). Default: 10.

    Returns:
        (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
            float32.
    """
    noise = generate_gaussian_noise_pt(img, sigma, gray_noise)
    out = img + noise
    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


# ----------------------- Random Gaussian Noise ----------------------- #
def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
    sigma = np.random.uniform(sigma_range[0], sigma_range[1])
    if np.random.uniform() < gray_prob:
        gray_noise = True
    else:
        gray_noise = False
    return generate_gaussian_noise(img, sigma, gray_noise)


def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    noise = random_generate_gaussian_noise(img, sigma_range, gray_prob)
    out = img + noise
    if clip and rounds:
        out = np.clip((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = np.clip(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
    sigma = torch.rand(
        img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
    gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
    gray_noise = (gray_noise < gray_prob).float()
    return generate_gaussian_noise_pt(img, sigma, gray_noise)


def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
    out = img + noise
    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out
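
# Usage sketch for the batched PyTorch path (shapes and ranges below are
# assumed values, not defaults of this module):
#
#     imgs = torch.rand(4, 3, 64, 64)                               # batch in [0, 1]
#     noisy = random_add_gaussian_noise_pt(imgs, sigma_range=(1, 25), gray_prob=0.4)
#     # each sample draws its own sigma; on average 40% of samples get
#     # channel-shared (gray) noise instead of independent per-channel noise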

# ----------------------- Poisson (Shot) Noise ----------------------- #


def generate_poisson_noise(img, scale=1.0, gray_noise=False):
    """Generate poisson noise.

    Ref: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        scale (float): Noise scale. Default: 1.0.
        gray_noise (bool): Whether to generate gray noise. Default: False.

    Returns:
        (Numpy array): Generated noise, shape (h, w, c), float32.
    """
    if gray_noise:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # round and clip image for counting vals correctly
    img = np.clip((img * 255.0).round(), 0, 255) / 255.
    vals = len(np.unique(img))
    vals = 2**np.ceil(np.log2(vals))
    out = np.float32(np.random.poisson(img * vals) / float(vals))
    noise = out - img
    if gray_noise:
        noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2)
    return noise * scale


def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False):
    """Add poisson noise.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        scale (float): Noise scale. Default: 1.0.
        gray_noise (bool): Whether to generate gray noise. Default: False.

    Returns:
        (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
            float32.
    """
    noise = generate_poisson_noise(img, scale, gray_noise)
    out = img + noise
    if clip and rounds:
        out = np.clip((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = np.clip(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
    """Generate a batch of poisson noise (PyTorch version).

    Args:
        img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
        scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
            Default: 1.0.
        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
            0 for False, 1 for True. Default: 0.

    Returns:
        (Tensor): Generated noise, shape (b, c, h, w), float32.
    """
    b, _, h, w = img.size()
    if isinstance(gray_noise, (float, int)):
        cal_gray_noise = gray_noise > 0
    else:
        gray_noise = gray_noise.view(b, 1, 1, 1)
        cal_gray_noise = torch.sum(gray_noise) > 0
    if cal_gray_noise:
        img_gray = rgb_to_grayscale(img, num_output_channels=1)
        # round and clip image for counting vals correctly
        img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
        # use for-loop to get the unique values for each sample
        vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
        vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
        vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
        out = torch.poisson(img_gray * vals) / vals
        noise_gray = out - img_gray
        noise_gray = noise_gray.expand(b, 3, h, w)

    # always calculate color noise
    # round and clip image for counting vals correctly
    img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
    # use for-loop to get the unique values for each sample
    vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
    vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
    vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
    out = torch.poisson(img * vals) / vals
    noise = out - img
    if cal_gray_noise:
        noise = noise * (1 - gray_noise) + noise_gray * gray_noise
    if not isinstance(scale, (float, int)):
        scale = scale.view(b, 1, 1, 1)
    return noise * scale


def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
    """Add poisson noise to a batch of images (PyTorch version).

    Args:
        img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
        scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
            Default: 1.0.
        gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
            0 for False, 1 for True. Default: 0.

    Returns:
        (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
            float32.
    """
    noise = generate_poisson_noise_pt(img, scale, gray_noise)
    out = img + noise
    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


# ----------------------- Random Poisson (Shot) Noise ----------------------- #


def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
    scale = np.random.uniform(scale_range[0], scale_range[1])
    if np.random.uniform() < gray_prob:
        gray_noise = True
    else:
        gray_noise = False
    return generate_poisson_noise(img, scale, gray_noise)


def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    noise = random_generate_poisson_noise(img, scale_range, gray_prob)
    out = img + noise
    if clip and rounds:
        out = np.clip((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = np.clip(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out


def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
    scale = torch.rand(
        img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
    gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
    gray_noise = (gray_noise < gray_prob).float()
    return generate_poisson_noise_pt(img, scale, gray_noise)


def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
    noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
    out = img + noise
    if clip and rounds:
        out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
    elif clip:
        out = torch.clamp(out, 0, 1)
    elif rounds:
        out = (out * 255.0).round() / 255.
    return out
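
# Usage sketch (the input array is an assumed placeholder): Poisson (shot)
# noise is signal-dependent, so brighter regions fluctuate more than dark ones.
#
#     img = np.random.rand(64, 64, 3).astype(np.float32)            # assumed BGR image in [0, 1]
#     shot = random_add_poisson_noise(img, scale_range=(0.05, 3.0), gray_prob=0.4)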

# ------------------------------------------------------------------------ #
# --------------------------- JPEG compression --------------------------- #
# ------------------------------------------------------------------------ #


def add_jpg_compression(img, quality=90):
    """Add JPG compression artifacts.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        quality (float): JPG compression quality. 0 for lowest quality, 100 for
            best quality. Default: 90.

    Returns:
        (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
            float32.
    """
    img = np.clip(img, 0, 1)
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
    _, encimg = cv2.imencode('.jpg', img * 255., encode_param)
    img = np.float32(cv2.imdecode(encimg, 1)) / 255.
    return img


def random_add_jpg_compression(img, quality_range=(90, 100)):
    """Randomly add JPG compression artifacts.

    Args:
        img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
        quality_range (tuple[float] | list[float]): JPG compression quality
            range. 0 for lowest quality, 100 for best quality.
            Default: (90, 100).

    Returns:
        (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
            float32.
    """
    quality = np.random.uniform(quality_range[0], quality_range[1])
    return add_jpg_compression(img, quality)
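
# End-to-end sketch (illustrative assumptions, not part of the original
# module): a minimal blur -> noise -> JPEG chain over a placeholder image.
#
#     img = np.random.rand(128, 128, 3).astype(np.float32)          # assumed BGR image in [0, 1]
#     k = circular_lowpass_kernel(np.pi / 3, kernel_size=13, pad_to=0)
#     blurred = cv2.filter2D(img, -1, k)                            # sinc (ringing) blur
#     noisy = random_add_gaussian_noise(blurred, sigma_range=(1, 30), gray_prob=0.4)
#     lq = random_add_jpg_compression(noisy, quality_range=(30, 95))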
custom_nodes/ComfyUI-ReActor/r_basicsr/data/ffhq_dataset.py
ADDED
@@ -0,0 +1,80 @@
import random
import time
from os import path as osp
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from r_basicsr.data.transforms import augment
from r_basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from r_basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class FFHQDataset(data.Dataset):
    """FFHQ dataset for StyleGAN.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            io_backend (dict): IO backend type and other kwarg.
            mean (list | tuple): Image mean.
            std (list | tuple): Image std.
            use_hflip (bool): Whether to horizontally flip.
    """

    def __init__(self, opt):
        super(FFHQDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']

        self.gt_folder = opt['dataroot_gt']
        self.mean = opt['mean']
        self.std = opt['std']

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = self.gt_folder
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # FFHQ has 70000 images in total
            self.paths = [osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)]

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # load gt image
        gt_path = self.paths[index]
        # avoid errors caused by high latency in reading files
        retry = 3
        while retry > 0:
            try:
                img_bytes = self.file_client.get(gt_path)
            except Exception as e:
                logger = get_root_logger()
                logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
                # change to another file to read
                index = random.randint(0, self.__len__() - 1)  # randint is inclusive on both ends
                gt_path = self.paths[index]
                time.sleep(1)  # sleep 1s for occasional server congestion
            else:
                break
            finally:
                retry -= 1
        img_gt = imfrombytes(img_bytes, float32=True)

        # random horizontal flip
        img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
        # normalize
        normalize(img_gt, self.mean, self.std, inplace=True)
        return {'gt': img_gt, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
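
# Instantiation sketch (option values are illustrative placeholders, not
# defaults of this node):
#
#     opt = {
#         'dataroot_gt': 'datasets/ffhq',        # assumed folder of 00000000.png ... 00069999.png
#         'io_backend': {'type': 'disk'},
#         'mean': [0.5, 0.5, 0.5],
#         'std': [0.5, 0.5, 0.5],
#         'use_hflip': True,
#     }
#     sample = FFHQDataset(opt)[0]               # {'gt': CHW tensor in [-1, 1], 'gt_path': ...}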
custom_nodes/ComfyUI-ReActor/r_basicsr/data/paired_image_dataset.py
ADDED
@@ -0,0 +1,108 @@
from torch.utils import data as data
from torchvision.transforms.functional import normalize

from r_basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
from r_basicsr.data.transforms import augment, paired_random_crop
from r_basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor
from r_basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class PairedImageDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.

    There are three modes:
    1. 'lmdb': Use lmdb files.
        If opt['io_backend'] == lmdb.
    2. 'meta_info_file': Use meta information file to generate paths.
        If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
    3. 'folder': Scan folders to generate paths.
        The rest.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            dataroot_lq (str): Data root path for lq.
            meta_info_file (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
                Default: '{}'.
            gt_size (int): Cropped patch size for gt patches.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).

            scale (int): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(PairedImageDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt['mean'] if 'mean' in opt else None
        self.std = opt['std'] if 'std' in opt else None

        self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
        if 'filename_tmpl' in opt:
            self.filename_tmpl = opt['filename_tmpl']
        else:
            self.filename_tmpl = '{}'

        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
            self.io_backend_opt['client_keys'] = ['lq', 'gt']
            self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
        elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
            self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
                                                          self.opt['meta_info_file'], self.filename_tmpl)
        else:
            self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        gt_path = self.paths[index]['gt_path']
        img_bytes = self.file_client.get(gt_path, 'gt')
        img_gt = imfrombytes(img_bytes, float32=True)
        lq_path = self.paths[index]['lq_path']
        img_bytes = self.file_client.get(lq_path, 'lq')
        img_lq = imfrombytes(img_bytes, float32=True)

        # augmentation for training
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            # random crop
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            # flip, rotation
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])

        # color space transform
        if 'color' in self.opt and self.opt['color'] == 'y':
            img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None]
            img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None]

        # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets
        # TODO: It is better to update the datasets, rather than force to crop
        if self.opt['phase'] != 'train':
            img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :]

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
        # normalize
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
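
# Folder-mode training sketch (paths and sizes are assumed placeholders):
#
#     opt = {
#         'dataroot_gt': 'datasets/DIV2K/HR',
#         'dataroot_lq': 'datasets/DIV2K/LR_x4',
#         'io_backend': {'type': 'disk'},
#         'scale': 4,
#         'phase': 'train',
#         'gt_size': 128,                        # paired LQ crops become 128 // 4 = 32
#         'use_hflip': True,
#         'use_rot': True,
#     }
#     pair = PairedImageDataset(opt)[0]          # {'lq', 'gt', 'lq_path', 'gt_path'}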
custom_nodes/ComfyUI-ReActor/r_basicsr/data/prefetch_dataloader.py
ADDED
@@ -0,0 +1,125 @@
import queue as Queue
import threading
import torch
from torch.utils.data import DataLoader


class PrefetchGenerator(threading.Thread):
    """A general prefetch generator.

    Ref:
    https://stackoverflow.com/questions/7323664/python-generator-pre-fetch

    Args:
        generator: Python generator.
        num_prefetch_queue (int): Depth of the prefetch queue.
    """

    def __init__(self, generator, num_prefetch_queue):
        threading.Thread.__init__(self)
        self.queue = Queue.Queue(num_prefetch_queue)
        self.generator = generator
        self.daemon = True
        self.start()

    def run(self):
        for item in self.generator:
            self.queue.put(item)
        self.queue.put(None)

    def __next__(self):
        next_item = self.queue.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __iter__(self):
        return self


class PrefetchDataLoader(DataLoader):
    """Prefetch version of dataloader.

    Ref:
    https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#

    TODO:
    Need to test on single gpu and ddp (multi-gpu). There is a known issue in
    ddp.

    Args:
        num_prefetch_queue (int): Depth of the prefetch queue.
        kwargs (dict): Other arguments for dataloader.
    """

    def __init__(self, num_prefetch_queue, **kwargs):
        self.num_prefetch_queue = num_prefetch_queue
        super(PrefetchDataLoader, self).__init__(**kwargs)

    def __iter__(self):
        return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)


class CPUPrefetcher():
    """CPU prefetcher.

    Args:
        loader: Dataloader.
    """

    def __init__(self, loader):
        self.ori_loader = loader
        self.loader = iter(loader)

    def next(self):
        try:
            return next(self.loader)
        except StopIteration:
            return None

    def reset(self):
        self.loader = iter(self.ori_loader)


class CUDAPrefetcher():
    """CUDA prefetcher.

    Ref:
    https://github.com/NVIDIA/apex/issues/304#

    It may consume more GPU memory.

    Args:
        loader: Dataloader.
        opt (dict): Options.
    """

    def __init__(self, loader, opt):
        self.ori_loader = loader
        self.loader = iter(loader)
        self.opt = opt
        self.stream = torch.cuda.Stream()
        self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
        self.preload()

    def preload(self):
        try:
            self.batch = next(self.loader)  # self.batch is a dict
        except StopIteration:
            self.batch = None
            return None
        # put tensors to gpu
        with torch.cuda.stream(self.stream):
            for k, v in self.batch.items():
                if torch.is_tensor(v):
                    self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        self.preload()
        return batch

    def reset(self):
        self.loader = iter(self.ori_loader)
        self.preload()
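
# Training-loop sketch for the CUDA prefetcher (the loader is an assumed
# stand-in for a real DataLoader of dict batches; requires a CUDA device):
#
#     loader = DataLoader(my_dict_dataset, batch_size=4)   # batches must be dicts of tensors
#     prefetcher = CUDAPrefetcher(loader, {'num_gpu': 1})
#     batch = prefetcher.next()      # batch N; batch N+1 is already copying on a side stream
#     while batch is not None:
#         ...                        # train on batch['gt'], already on the GPU
#         batch = prefetcher.next()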
custom_nodes/ComfyUI-ReActor/r_basicsr/data/realesrgan_dataset.py
ADDED
@@ -0,0 +1,193 @@
import cv2
import math
import numpy as np
import os
import os.path as osp
import random
import time
import torch
from torch.utils import data as data

from r_basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
from r_basicsr.data.transforms import augment
from r_basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from r_basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register(suffix='basicsr')
class RealESRGANDataset(data.Dataset):
    """Dataset used for Real-ESRGAN model:
    Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    It loads gt (Ground-Truth) images and augments them.
    It also generates blur kernels and sinc kernels for generating low-quality images.
    Note that the low-quality images are processed in tensors on GPUs for faster processing.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            meta_info (str): Path for meta information file.
            io_backend (dict): IO backend type and other kwarg.
            use_hflip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
            Please see more options in the codes.
    """

    def __init__(self, opt):
        super(RealESRGANDataset, self).__init__()
        self.opt = opt
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.gt_folder = opt['dataroot_gt']

        # file client (lmdb io backend)
        if self.io_backend_opt['type'] == 'lmdb':
            self.io_backend_opt['db_paths'] = [self.gt_folder]
            self.io_backend_opt['client_keys'] = ['gt']
            if not self.gt_folder.endswith('.lmdb'):
                raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
            with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
                self.paths = [line.split('.')[0] for line in fin]
        else:
            # disk backend with meta_info
            # Each line in the meta_info describes the relative path to an image
            with open(self.opt['meta_info']) as fin:
                paths = [line.strip().split(' ')[0] for line in fin]
            self.paths = [os.path.join(self.gt_folder, v) for v in paths]

        # blur settings for the first degradation
        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']  # a list for each kernel probability
        self.blur_sigma = opt['blur_sigma']
        self.betag_range = opt['betag_range']  # betag used in generalized Gaussian blur kernels
        self.betap_range = opt['betap_range']  # betap used in plateau blur kernels
        self.sinc_prob = opt['sinc_prob']  # the probability for sinc filters

        # blur settings for the second degradation
        self.blur_kernel_size2 = opt['blur_kernel_size2']
        self.kernel_list2 = opt['kernel_list2']
        self.kernel_prob2 = opt['kernel_prob2']
        self.blur_sigma2 = opt['blur_sigma2']
        self.betag_range2 = opt['betag_range2']
        self.betap_range2 = opt['betap_range2']
        self.sinc_prob2 = opt['sinc_prob2']

        # a final sinc filter
        self.final_sinc_prob = opt['final_sinc_prob']

        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
        # TODO: kernel range is now hard-coded, should be in the config file
        self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with a pulse tensor brings no blurry effect
        self.pulse_tensor[10, 10] = 1

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # -------------------------------- Load gt images -------------------------------- #
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        gt_path = self.paths[index]
        # avoid errors caused by high latency in reading files
        retry = 3
        while retry > 0:
            try:
                img_bytes = self.file_client.get(gt_path, 'gt')
            except (IOError, OSError) as e:
                logger = get_root_logger()
                logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
                # change to another file to read
                index = random.randint(0, self.__len__() - 1)  # randint is inclusive on both ends
                gt_path = self.paths[index]
                time.sleep(1)  # sleep 1s for occasional server congestion
            else:
                break
            finally:
                retry -= 1
        img_gt = imfrombytes(img_bytes, float32=True)

        # -------------------- Do augmentation for training: flip, rotation -------------------- #
        img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])

        # crop or pad to 400
        # TODO: 400 is hard-coded. You may change it accordingly
        h, w = img_gt.shape[0:2]
        crop_pad_size = 400
        # pad
        if h < crop_pad_size or w < crop_pad_size:
            pad_h = max(0, crop_pad_size - h)
            pad_w = max(0, crop_pad_size - w)
            img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
        # crop
        if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
            h, w = img_gt.shape[0:2]
            # randomly choose top and left coordinates
            top = random.randint(0, h - crop_pad_size)
            left = random.randint(0, w - crop_pad_size)
            img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]

        # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.opt['sinc_prob']:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                self.kernel_list,
                self.kernel_prob,
                kernel_size,
                self.blur_sigma,
                self.blur_sigma, [-math.pi, math.pi],
                self.betag_range,
                self.betap_range,
                noise_range=None)
        # pad kernel
        pad_size = (21 - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.opt['sinc_prob2']:
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel2 = random_mixed_kernels(
                self.kernel_list2,
                self.kernel_prob2,
                kernel_size,
                self.blur_sigma2,
                self.blur_sigma2, [-math.pi, math.pi],
                self.betag_range2,
                self.betap_range2,
                noise_range=None)

        # pad kernel
        pad_size = (21 - kernel_size) // 2
        kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))

        # ------------------------------------- the final sinc kernel ------------------------------------- #
        if np.random.uniform() < self.opt['final_sinc_prob']:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
            sinc_kernel = torch.FloatTensor(sinc_kernel)
        else:
            sinc_kernel = self.pulse_tensor

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
        kernel = torch.FloatTensor(kernel)
        kernel2 = torch.FloatTensor(kernel2)

        return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
        return return_d

    def __len__(self):
        return len(self.paths)
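
# Options sketch (values mirror common Real-ESRGAN training configs but are
# assumptions here, not defaults shipped with this repo):
#
#     opt = {
#         'dataroot_gt': 'datasets/DF2K', 'meta_info': 'datasets/DF2K/meta_info.txt',
#         'io_backend': {'type': 'disk'}, 'use_hflip': True, 'use_rot': False,
#         'blur_kernel_size': 21, 'kernel_list': ['iso', 'aniso'], 'kernel_prob': [0.5, 0.5],
#         'blur_sigma': [0.2, 3], 'betag_range': [0.5, 4], 'betap_range': [1, 2], 'sinc_prob': 0.1,
#         'blur_kernel_size2': 21, 'kernel_list2': ['iso', 'aniso'], 'kernel_prob2': [0.5, 0.5],
#         'blur_sigma2': [0.2, 1.5], 'betag_range2': [0.5, 4], 'betap_range2': [1, 2], 'sinc_prob2': 0.1,
#         'final_sinc_prob': 0.8,
#     }
#     # RealESRGANDataset(opt)[0] -> {'gt', 'kernel1', 'kernel2', 'sinc_kernel', 'gt_path'};
#     # the degradation pipeline itself runs later, on GPU, using these kernels.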