model2 committed on
Commit 6527198 · 1 Parent(s): 82c6075

Add reactor

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete change set.
Files changed (50):
  1. .gitignore +0 -1
  2. app.py +2 -2
  3. custom_nodes/ComfyUI-ReActor/.gitignore +5 -0
  4. custom_nodes/ComfyUI-ReActor/LICENSE +674 -0
  5. custom_nodes/ComfyUI-ReActor/README.md +488 -0
  6. custom_nodes/ComfyUI-ReActor/README_RU.md +497 -0
  7. custom_nodes/ComfyUI-ReActor/__init__.py +39 -0
  8. custom_nodes/ComfyUI-ReActor/install.bat +37 -0
  9. custom_nodes/ComfyUI-ReActor/install.py +104 -0
  10. custom_nodes/ComfyUI-ReActor/modules/__init__.py +0 -0
  11. custom_nodes/ComfyUI-ReActor/modules/images.py +0 -0
  12. custom_nodes/ComfyUI-ReActor/modules/processing.py +13 -0
  13. custom_nodes/ComfyUI-ReActor/modules/scripts.py +13 -0
  14. custom_nodes/ComfyUI-ReActor/modules/scripts_postprocessing.py +0 -0
  15. custom_nodes/ComfyUI-ReActor/modules/shared.py +19 -0
  16. custom_nodes/ComfyUI-ReActor/nodes.py +1364 -0
  17. custom_nodes/ComfyUI-ReActor/pyproject.toml +15 -0
  18. custom_nodes/ComfyUI-ReActor/r_basicsr/__init__.py +12 -0
  19. custom_nodes/ComfyUI-ReActor/r_basicsr/archs/__init__.py +25 -0
  20. custom_nodes/ComfyUI-ReActor/r_basicsr/archs/arch_util.py +322 -0
  21. custom_nodes/ComfyUI-ReActor/r_basicsr/archs/basicvsr_arch.py +336 -0
  22. custom_nodes/ComfyUI-ReActor/r_basicsr/archs/basicvsrpp_arch.py +407 -0
  23. custom_nodes/ComfyUI-ReActor/r_basicsr/archs/dfdnet_arch.py +169 -0
  24. custom_nodes/ComfyUI-ReActor/r_basicsr/archs/dfdnet_util.py +162 -0
  25. custom_nodes/ComfyUI-ReActor/r_basicsr/archs/discriminator_arch.py +150 -0
  26. custom_nodes/ComfyUI-ReActor/r_basicsr/archs/duf_arch.py +277 -0
  27. custom_nodes/ComfyUI-ReActor/r_basicsr/archs/ecbsr_arch.py +274 -0
  28. custom_nodes/ComfyUI-ReActor/r_basicsr/archs/edsr_arch.py +61 -0
  29. custom_nodes/ComfyUI-ReActor/r_basicsr/archs/edvr_arch.py +383 -0
  30. custom_nodes/ComfyUI-ReActor/r_basicsr/archs/hifacegan_arch.py +259 -0
  31. custom_nodes/ComfyUI-ReActor/r_basicsr/archs/hifacegan_util.py +255 -0
  32. custom_nodes/ComfyUI-ReActor/r_basicsr/archs/inception.py +307 -0
  33. custom_nodes/ComfyUI-ReActor/r_basicsr/archs/rcan_arch.py +135 -0
  34. custom_nodes/ComfyUI-ReActor/r_basicsr/archs/ridnet_arch.py +184 -0
  35. custom_nodes/ComfyUI-ReActor/r_basicsr/archs/rrdbnet_arch.py +119 -0
  36. custom_nodes/ComfyUI-ReActor/r_basicsr/archs/spynet_arch.py +96 -0
  37. custom_nodes/ComfyUI-ReActor/r_basicsr/archs/srresnet_arch.py +65 -0
  38. custom_nodes/ComfyUI-ReActor/r_basicsr/archs/srvgg_arch.py +70 -0
  39. custom_nodes/ComfyUI-ReActor/r_basicsr/archs/stylegan2_arch.py +799 -0
  40. custom_nodes/ComfyUI-ReActor/r_basicsr/archs/swinir_arch.py +956 -0
  41. custom_nodes/ComfyUI-ReActor/r_basicsr/archs/tof_arch.py +172 -0
  42. custom_nodes/ComfyUI-ReActor/r_basicsr/archs/vgg_arch.py +161 -0
  43. custom_nodes/ComfyUI-ReActor/r_basicsr/data/__init__.py +101 -0
  44. custom_nodes/ComfyUI-ReActor/r_basicsr/data/data_sampler.py +48 -0
  45. custom_nodes/ComfyUI-ReActor/r_basicsr/data/data_util.py +313 -0
  46. custom_nodes/ComfyUI-ReActor/r_basicsr/data/degradations.py +768 -0
  47. custom_nodes/ComfyUI-ReActor/r_basicsr/data/ffhq_dataset.py +80 -0
  48. custom_nodes/ComfyUI-ReActor/r_basicsr/data/paired_image_dataset.py +108 -0
  49. custom_nodes/ComfyUI-ReActor/r_basicsr/data/prefetch_dataloader.py +125 -0
  50. custom_nodes/ComfyUI-ReActor/r_basicsr/data/realesrgan_dataset.py +193 -0
.gitignore CHANGED
@@ -5,7 +5,6 @@ __pycache__/
 !/input/example.png
 /models/
 /temp/
-/custom_nodes/
 !custom_nodes/example_node.py.example
 extra_model_paths.yaml
 /.vs
app.py CHANGED
@@ -6,10 +6,10 @@ import gradio as gr
 import torch
 from huggingface_hub import hf_hub_download
 from nodes import NODE_CLASS_MAPPINGS
-import spaces
 from comfy import model_management
 
-@spaces.GPU(duration=60) #modify the duration for the average it takes for your worflow to run, in seconds
+# import spaces
+# @spaces.GPU(duration=60) #modify the duration for the average it takes for your worflow to run, in seconds
 
 
 def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
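Note on this change: `@spaces.GPU` is the Hugging Face ZeroGPU decorator that requests a GPU for the duration of a call, so commenting it out together with `import spaces` disables ZeroGPU allocation for this Space. Below is a minimal sketch of how the decorator is normally attached to the function that executes a workflow; the function name `run_workflow` is hypothetical and not taken from this repository.

```python
# Sketch only: typical ZeroGPU usage in a Hugging Face Space.
# `run_workflow` is a hypothetical stand-in for the function that runs the ComfyUI graph.
import spaces  # provided for Hugging Face Spaces (ZeroGPU)

@spaces.GPU(duration=60)  # request a GPU for up to ~60 seconds per call
def run_workflow(prompt: str):
    # ... build and execute the ComfyUI workflow here ...
    return prompt
```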
custom_nodes/ComfyUI-ReActor/.gitignore ADDED
@@ -0,0 +1,5 @@
+__pycache__/
+*$py.class
+.vscode/
+example
+input
custom_nodes/ComfyUI-ReActor/LICENSE ADDED
@@ -0,0 +1,674 @@
1
+ GNU GENERAL PUBLIC LICENSE
2
+ Version 3, 29 June 2007
3
+
4
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5
+ Everyone is permitted to copy and distribute verbatim copies
6
+ of this license document, but changing it is not allowed.
7
+
8
+ Preamble
9
+
10
+ The GNU General Public License is a free, copyleft license for
11
+ software and other kinds of works.
12
+
13
+ The licenses for most software and other practical works are designed
14
+ to take away your freedom to share and change the works. By contrast,
15
+ the GNU General Public License is intended to guarantee your freedom to
16
+ share and change all versions of a program--to make sure it remains free
17
+ software for all its users. We, the Free Software Foundation, use the
18
+ GNU General Public License for most of our software; it applies also to
19
+ any other work released this way by its authors. You can apply it to
20
+ your programs, too.
21
+
22
+ When we speak of free software, we are referring to freedom, not
23
+ price. Our General Public Licenses are designed to make sure that you
24
+ have the freedom to distribute copies of free software (and charge for
25
+ them if you wish), that you receive source code or can get it if you
26
+ want it, that you can change the software or use pieces of it in new
27
+ free programs, and that you know you can do these things.
28
+
29
+ To protect your rights, we need to prevent others from denying you
30
+ these rights or asking you to surrender the rights. Therefore, you have
31
+ certain responsibilities if you distribute copies of the software, or if
32
+ you modify it: responsibilities to respect the freedom of others.
33
+
34
+ For example, if you distribute copies of such a program, whether
35
+ gratis or for a fee, you must pass on to the recipients the same
36
+ freedoms that you received. You must make sure that they, too, receive
37
+ or can get the source code. And you must show them these terms so they
38
+ know their rights.
39
+
40
+ Developers that use the GNU GPL protect your rights with two steps:
41
+ (1) assert copyright on the software, and (2) offer you this License
42
+ giving you legal permission to copy, distribute and/or modify it.
43
+
44
+ For the developers' and authors' protection, the GPL clearly explains
45
+ that there is no warranty for this free software. For both users' and
46
+ authors' sake, the GPL requires that modified versions be marked as
47
+ changed, so that their problems will not be attributed erroneously to
48
+ authors of previous versions.
49
+
50
+ Some devices are designed to deny users access to install or run
51
+ modified versions of the software inside them, although the manufacturer
52
+ can do so. This is fundamentally incompatible with the aim of
53
+ protecting users' freedom to change the software. The systematic
54
+ pattern of such abuse occurs in the area of products for individuals to
55
+ use, which is precisely where it is most unacceptable. Therefore, we
56
+ have designed this version of the GPL to prohibit the practice for those
57
+ products. If such problems arise substantially in other domains, we
58
+ stand ready to extend this provision to those domains in future versions
59
+ of the GPL, as needed to protect the freedom of users.
60
+
61
+ Finally, every program is threatened constantly by software patents.
62
+ States should not allow patents to restrict development and use of
63
+ software on general-purpose computers, but in those that do, we wish to
64
+ avoid the special danger that patents applied to a free program could
65
+ make it effectively proprietary. To prevent this, the GPL assures that
66
+ patents cannot be used to render the program non-free.
67
+
68
+ The precise terms and conditions for copying, distribution and
69
+ modification follow.
70
+
71
+ TERMS AND CONDITIONS
72
+
73
+ 0. Definitions.
74
+
75
+ "This License" refers to version 3 of the GNU General Public License.
76
+
77
+ "Copyright" also means copyright-like laws that apply to other kinds of
78
+ works, such as semiconductor masks.
79
+
80
+ "The Program" refers to any copyrightable work licensed under this
81
+ License. Each licensee is addressed as "you". "Licensees" and
82
+ "recipients" may be individuals or organizations.
83
+
84
+ To "modify" a work means to copy from or adapt all or part of the work
85
+ in a fashion requiring copyright permission, other than the making of an
86
+ exact copy. The resulting work is called a "modified version" of the
87
+ earlier work or a work "based on" the earlier work.
88
+
89
+ A "covered work" means either the unmodified Program or a work based
90
+ on the Program.
91
+
92
+ To "propagate" a work means to do anything with it that, without
93
+ permission, would make you directly or secondarily liable for
94
+ infringement under applicable copyright law, except executing it on a
95
+ computer or modifying a private copy. Propagation includes copying,
96
+ distribution (with or without modification), making available to the
97
+ public, and in some countries other activities as well.
98
+
99
+ To "convey" a work means any kind of propagation that enables other
100
+ parties to make or receive copies. Mere interaction with a user through
101
+ a computer network, with no transfer of a copy, is not conveying.
102
+
103
+ An interactive user interface displays "Appropriate Legal Notices"
104
+ to the extent that it includes a convenient and prominently visible
105
+ feature that (1) displays an appropriate copyright notice, and (2)
106
+ tells the user that there is no warranty for the work (except to the
107
+ extent that warranties are provided), that licensees may convey the
108
+ work under this License, and how to view a copy of this License. If
109
+ the interface presents a list of user commands or options, such as a
110
+ menu, a prominent item in the list meets this criterion.
111
+
112
+ 1. Source Code.
113
+
114
+ The "source code" for a work means the preferred form of the work
115
+ for making modifications to it. "Object code" means any non-source
116
+ form of a work.
117
+
118
+ A "Standard Interface" means an interface that either is an official
119
+ standard defined by a recognized standards body, or, in the case of
120
+ interfaces specified for a particular programming language, one that
121
+ is widely used among developers working in that language.
122
+
123
+ The "System Libraries" of an executable work include anything, other
124
+ than the work as a whole, that (a) is included in the normal form of
125
+ packaging a Major Component, but which is not part of that Major
126
+ Component, and (b) serves only to enable use of the work with that
127
+ Major Component, or to implement a Standard Interface for which an
128
+ implementation is available to the public in source code form. A
129
+ "Major Component", in this context, means a major essential component
130
+ (kernel, window system, and so on) of the specific operating system
131
+ (if any) on which the executable work runs, or a compiler used to
132
+ produce the work, or an object code interpreter used to run it.
133
+
134
+ The "Corresponding Source" for a work in object code form means all
135
+ the source code needed to generate, install, and (for an executable
136
+ work) run the object code and to modify the work, including scripts to
137
+ control those activities. However, it does not include the work's
138
+ System Libraries, or general-purpose tools or generally available free
139
+ programs which are used unmodified in performing those activities but
140
+ which are not part of the work. For example, Corresponding Source
141
+ includes interface definition files associated with source files for
142
+ the work, and the source code for shared libraries and dynamically
143
+ linked subprograms that the work is specifically designed to require,
144
+ such as by intimate data communication or control flow between those
145
+ subprograms and other parts of the work.
146
+
147
+ The Corresponding Source need not include anything that users
148
+ can regenerate automatically from other parts of the Corresponding
149
+ Source.
150
+
151
+ The Corresponding Source for a work in source code form is that
152
+ same work.
153
+
154
+ 2. Basic Permissions.
155
+
156
+ All rights granted under this License are granted for the term of
157
+ copyright on the Program, and are irrevocable provided the stated
158
+ conditions are met. This License explicitly affirms your unlimited
159
+ permission to run the unmodified Program. The output from running a
160
+ covered work is covered by this License only if the output, given its
161
+ content, constitutes a covered work. This License acknowledges your
162
+ rights of fair use or other equivalent, as provided by copyright law.
163
+
164
+ You may make, run and propagate covered works that you do not
165
+ convey, without conditions so long as your license otherwise remains
166
+ in force. You may convey covered works to others for the sole purpose
167
+ of having them make modifications exclusively for you, or provide you
168
+ with facilities for running those works, provided that you comply with
169
+ the terms of this License in conveying all material for which you do
170
+ not control copyright. Those thus making or running the covered works
171
+ for you must do so exclusively on your behalf, under your direction
172
+ and control, on terms that prohibit them from making any copies of
173
+ your copyrighted material outside their relationship with you.
174
+
175
+ Conveying under any other circumstances is permitted solely under
176
+ the conditions stated below. Sublicensing is not allowed; section 10
177
+ makes it unnecessary.
178
+
179
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180
+
181
+ No covered work shall be deemed part of an effective technological
182
+ measure under any applicable law fulfilling obligations under article
183
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184
+ similar laws prohibiting or restricting circumvention of such
185
+ measures.
186
+
187
+ When you convey a covered work, you waive any legal power to forbid
188
+ circumvention of technological measures to the extent such circumvention
189
+ is effected by exercising rights under this License with respect to
190
+ the covered work, and you disclaim any intention to limit operation or
191
+ modification of the work as a means of enforcing, against the work's
192
+ users, your or third parties' legal rights to forbid circumvention of
193
+ technological measures.
194
+
195
+ 4. Conveying Verbatim Copies.
196
+
197
+ You may convey verbatim copies of the Program's source code as you
198
+ receive it, in any medium, provided that you conspicuously and
199
+ appropriately publish on each copy an appropriate copyright notice;
200
+ keep intact all notices stating that this License and any
201
+ non-permissive terms added in accord with section 7 apply to the code;
202
+ keep intact all notices of the absence of any warranty; and give all
203
+ recipients a copy of this License along with the Program.
204
+
205
+ You may charge any price or no price for each copy that you convey,
206
+ and you may offer support or warranty protection for a fee.
207
+
208
+ 5. Conveying Modified Source Versions.
209
+
210
+ You may convey a work based on the Program, or the modifications to
211
+ produce it from the Program, in the form of source code under the
212
+ terms of section 4, provided that you also meet all of these conditions:
213
+
214
+ a) The work must carry prominent notices stating that you modified
215
+ it, and giving a relevant date.
216
+
217
+ b) The work must carry prominent notices stating that it is
218
+ released under this License and any conditions added under section
219
+ 7. This requirement modifies the requirement in section 4 to
220
+ "keep intact all notices".
221
+
222
+ c) You must license the entire work, as a whole, under this
223
+ License to anyone who comes into possession of a copy. This
224
+ License will therefore apply, along with any applicable section 7
225
+ additional terms, to the whole of the work, and all its parts,
226
+ regardless of how they are packaged. This License gives no
227
+ permission to license the work in any other way, but it does not
228
+ invalidate such permission if you have separately received it.
229
+
230
+ d) If the work has interactive user interfaces, each must display
231
+ Appropriate Legal Notices; however, if the Program has interactive
232
+ interfaces that do not display Appropriate Legal Notices, your
233
+ work need not make them do so.
234
+
235
+ A compilation of a covered work with other separate and independent
236
+ works, which are not by their nature extensions of the covered work,
237
+ and which are not combined with it such as to form a larger program,
238
+ in or on a volume of a storage or distribution medium, is called an
239
+ "aggregate" if the compilation and its resulting copyright are not
240
+ used to limit the access or legal rights of the compilation's users
241
+ beyond what the individual works permit. Inclusion of a covered work
242
+ in an aggregate does not cause this License to apply to the other
243
+ parts of the aggregate.
244
+
245
+ 6. Conveying Non-Source Forms.
246
+
247
+ You may convey a covered work in object code form under the terms
248
+ of sections 4 and 5, provided that you also convey the
249
+ machine-readable Corresponding Source under the terms of this License,
250
+ in one of these ways:
251
+
252
+ a) Convey the object code in, or embodied in, a physical product
253
+ (including a physical distribution medium), accompanied by the
254
+ Corresponding Source fixed on a durable physical medium
255
+ customarily used for software interchange.
256
+
257
+ b) Convey the object code in, or embodied in, a physical product
258
+ (including a physical distribution medium), accompanied by a
259
+ written offer, valid for at least three years and valid for as
260
+ long as you offer spare parts or customer support for that product
261
+ model, to give anyone who possesses the object code either (1) a
262
+ copy of the Corresponding Source for all the software in the
263
+ product that is covered by this License, on a durable physical
264
+ medium customarily used for software interchange, for a price no
265
+ more than your reasonable cost of physically performing this
266
+ conveying of source, or (2) access to copy the
267
+ Corresponding Source from a network server at no charge.
268
+
269
+ c) Convey individual copies of the object code with a copy of the
270
+ written offer to provide the Corresponding Source. This
271
+ alternative is allowed only occasionally and noncommercially, and
272
+ only if you received the object code with such an offer, in accord
273
+ with subsection 6b.
274
+
275
+ d) Convey the object code by offering access from a designated
276
+ place (gratis or for a charge), and offer equivalent access to the
277
+ Corresponding Source in the same way through the same place at no
278
+ further charge. You need not require recipients to copy the
279
+ Corresponding Source along with the object code. If the place to
280
+ copy the object code is a network server, the Corresponding Source
281
+ may be on a different server (operated by you or a third party)
282
+ that supports equivalent copying facilities, provided you maintain
283
+ clear directions next to the object code saying where to find the
284
+ Corresponding Source. Regardless of what server hosts the
285
+ Corresponding Source, you remain obligated to ensure that it is
286
+ available for as long as needed to satisfy these requirements.
287
+
288
+ e) Convey the object code using peer-to-peer transmission, provided
289
+ you inform other peers where the object code and Corresponding
290
+ Source of the work are being offered to the general public at no
291
+ charge under subsection 6d.
292
+
293
+ A separable portion of the object code, whose source code is excluded
294
+ from the Corresponding Source as a System Library, need not be
295
+ included in conveying the object code work.
296
+
297
+ A "User Product" is either (1) a "consumer product", which means any
298
+ tangible personal property which is normally used for personal, family,
299
+ or household purposes, or (2) anything designed or sold for incorporation
300
+ into a dwelling. In determining whether a product is a consumer product,
301
+ doubtful cases shall be resolved in favor of coverage. For a particular
302
+ product received by a particular user, "normally used" refers to a
303
+ typical or common use of that class of product, regardless of the status
304
+ of the particular user or of the way in which the particular user
305
+ actually uses, or expects or is expected to use, the product. A product
306
+ is a consumer product regardless of whether the product has substantial
307
+ commercial, industrial or non-consumer uses, unless such uses represent
308
+ the only significant mode of use of the product.
309
+
310
+ "Installation Information" for a User Product means any methods,
311
+ procedures, authorization keys, or other information required to install
312
+ and execute modified versions of a covered work in that User Product from
313
+ a modified version of its Corresponding Source. The information must
314
+ suffice to ensure that the continued functioning of the modified object
315
+ code is in no case prevented or interfered with solely because
316
+ modification has been made.
317
+
318
+ If you convey an object code work under this section in, or with, or
319
+ specifically for use in, a User Product, and the conveying occurs as
320
+ part of a transaction in which the right of possession and use of the
321
+ User Product is transferred to the recipient in perpetuity or for a
322
+ fixed term (regardless of how the transaction is characterized), the
323
+ Corresponding Source conveyed under this section must be accompanied
324
+ by the Installation Information. But this requirement does not apply
325
+ if neither you nor any third party retains the ability to install
326
+ modified object code on the User Product (for example, the work has
327
+ been installed in ROM).
328
+
329
+ The requirement to provide Installation Information does not include a
330
+ requirement to continue to provide support service, warranty, or updates
331
+ for a work that has been modified or installed by the recipient, or for
332
+ the User Product in which it has been modified or installed. Access to a
333
+ network may be denied when the modification itself materially and
334
+ adversely affects the operation of the network or violates the rules and
335
+ protocols for communication across the network.
336
+
337
+ Corresponding Source conveyed, and Installation Information provided,
338
+ in accord with this section must be in a format that is publicly
339
+ documented (and with an implementation available to the public in
340
+ source code form), and must require no special password or key for
341
+ unpacking, reading or copying.
342
+
343
+ 7. Additional Terms.
344
+
345
+ "Additional permissions" are terms that supplement the terms of this
346
+ License by making exceptions from one or more of its conditions.
347
+ Additional permissions that are applicable to the entire Program shall
348
+ be treated as though they were included in this License, to the extent
349
+ that they are valid under applicable law. If additional permissions
350
+ apply only to part of the Program, that part may be used separately
351
+ under those permissions, but the entire Program remains governed by
352
+ this License without regard to the additional permissions.
353
+
354
+ When you convey a copy of a covered work, you may at your option
355
+ remove any additional permissions from that copy, or from any part of
356
+ it. (Additional permissions may be written to require their own
357
+ removal in certain cases when you modify the work.) You may place
358
+ additional permissions on material, added by you to a covered work,
359
+ for which you have or can give appropriate copyright permission.
360
+
361
+ Notwithstanding any other provision of this License, for material you
362
+ add to a covered work, you may (if authorized by the copyright holders of
363
+ that material) supplement the terms of this License with terms:
364
+
365
+ a) Disclaiming warranty or limiting liability differently from the
366
+ terms of sections 15 and 16 of this License; or
367
+
368
+ b) Requiring preservation of specified reasonable legal notices or
369
+ author attributions in that material or in the Appropriate Legal
370
+ Notices displayed by works containing it; or
371
+
372
+ c) Prohibiting misrepresentation of the origin of that material, or
373
+ requiring that modified versions of such material be marked in
374
+ reasonable ways as different from the original version; or
375
+
376
+ d) Limiting the use for publicity purposes of names of licensors or
377
+ authors of the material; or
378
+
379
+ e) Declining to grant rights under trademark law for use of some
380
+ trade names, trademarks, or service marks; or
381
+
382
+ f) Requiring indemnification of licensors and authors of that
383
+ material by anyone who conveys the material (or modified versions of
384
+ it) with contractual assumptions of liability to the recipient, for
385
+ any liability that these contractual assumptions directly impose on
386
+ those licensors and authors.
387
+
388
+ All other non-permissive additional terms are considered "further
389
+ restrictions" within the meaning of section 10. If the Program as you
390
+ received it, or any part of it, contains a notice stating that it is
391
+ governed by this License along with a term that is a further
392
+ restriction, you may remove that term. If a license document contains
393
+ a further restriction but permits relicensing or conveying under this
394
+ License, you may add to a covered work material governed by the terms
395
+ of that license document, provided that the further restriction does
396
+ not survive such relicensing or conveying.
397
+
398
+ If you add terms to a covered work in accord with this section, you
399
+ must place, in the relevant source files, a statement of the
400
+ additional terms that apply to those files, or a notice indicating
401
+ where to find the applicable terms.
402
+
403
+ Additional terms, permissive or non-permissive, may be stated in the
404
+ form of a separately written license, or stated as exceptions;
405
+ the above requirements apply either way.
406
+
407
+ 8. Termination.
408
+
409
+ You may not propagate or modify a covered work except as expressly
410
+ provided under this License. Any attempt otherwise to propagate or
411
+ modify it is void, and will automatically terminate your rights under
412
+ this License (including any patent licenses granted under the third
413
+ paragraph of section 11).
414
+
415
+ However, if you cease all violation of this License, then your
416
+ license from a particular copyright holder is reinstated (a)
417
+ provisionally, unless and until the copyright holder explicitly and
418
+ finally terminates your license, and (b) permanently, if the copyright
419
+ holder fails to notify you of the violation by some reasonable means
420
+ prior to 60 days after the cessation.
421
+
422
+ Moreover, your license from a particular copyright holder is
423
+ reinstated permanently if the copyright holder notifies you of the
424
+ violation by some reasonable means, this is the first time you have
425
+ received notice of violation of this License (for any work) from that
426
+ copyright holder, and you cure the violation prior to 30 days after
427
+ your receipt of the notice.
428
+
429
+ Termination of your rights under this section does not terminate the
430
+ licenses of parties who have received copies or rights from you under
431
+ this License. If your rights have been terminated and not permanently
432
+ reinstated, you do not qualify to receive new licenses for the same
433
+ material under section 10.
434
+
435
+ 9. Acceptance Not Required for Having Copies.
436
+
437
+ You are not required to accept this License in order to receive or
438
+ run a copy of the Program. Ancillary propagation of a covered work
439
+ occurring solely as a consequence of using peer-to-peer transmission
440
+ to receive a copy likewise does not require acceptance. However,
441
+ nothing other than this License grants you permission to propagate or
442
+ modify any covered work. These actions infringe copyright if you do
443
+ not accept this License. Therefore, by modifying or propagating a
444
+ covered work, you indicate your acceptance of this License to do so.
445
+
446
+ 10. Automatic Licensing of Downstream Recipients.
447
+
448
+ Each time you convey a covered work, the recipient automatically
449
+ receives a license from the original licensors, to run, modify and
450
+ propagate that work, subject to this License. You are not responsible
451
+ for enforcing compliance by third parties with this License.
452
+
453
+ An "entity transaction" is a transaction transferring control of an
454
+ organization, or substantially all assets of one, or subdividing an
455
+ organization, or merging organizations. If propagation of a covered
456
+ work results from an entity transaction, each party to that
457
+ transaction who receives a copy of the work also receives whatever
458
+ licenses to the work the party's predecessor in interest had or could
459
+ give under the previous paragraph, plus a right to possession of the
460
+ Corresponding Source of the work from the predecessor in interest, if
461
+ the predecessor has it or can get it with reasonable efforts.
462
+
463
+ You may not impose any further restrictions on the exercise of the
464
+ rights granted or affirmed under this License. For example, you may
465
+ not impose a license fee, royalty, or other charge for exercise of
466
+ rights granted under this License, and you may not initiate litigation
467
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
468
+ any patent claim is infringed by making, using, selling, offering for
469
+ sale, or importing the Program or any portion of it.
470
+
471
+ 11. Patents.
472
+
473
+ A "contributor" is a copyright holder who authorizes use under this
474
+ License of the Program or a work on which the Program is based. The
475
+ work thus licensed is called the contributor's "contributor version".
476
+
477
+ A contributor's "essential patent claims" are all patent claims
478
+ owned or controlled by the contributor, whether already acquired or
479
+ hereafter acquired, that would be infringed by some manner, permitted
480
+ by this License, of making, using, or selling its contributor version,
481
+ but do not include claims that would be infringed only as a
482
+ consequence of further modification of the contributor version. For
483
+ purposes of this definition, "control" includes the right to grant
484
+ patent sublicenses in a manner consistent with the requirements of
485
+ this License.
486
+
487
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
488
+ patent license under the contributor's essential patent claims, to
489
+ make, use, sell, offer for sale, import and otherwise run, modify and
490
+ propagate the contents of its contributor version.
491
+
492
+ In the following three paragraphs, a "patent license" is any express
493
+ agreement or commitment, however denominated, not to enforce a patent
494
+ (such as an express permission to practice a patent or covenant not to
495
+ sue for patent infringement). To "grant" such a patent license to a
496
+ party means to make such an agreement or commitment not to enforce a
497
+ patent against the party.
498
+
499
+ If you convey a covered work, knowingly relying on a patent license,
500
+ and the Corresponding Source of the work is not available for anyone
501
+ to copy, free of charge and under the terms of this License, through a
502
+ publicly available network server or other readily accessible means,
503
+ then you must either (1) cause the Corresponding Source to be so
504
+ available, or (2) arrange to deprive yourself of the benefit of the
505
+ patent license for this particular work, or (3) arrange, in a manner
506
+ consistent with the requirements of this License, to extend the patent
507
+ license to downstream recipients. "Knowingly relying" means you have
508
+ actual knowledge that, but for the patent license, your conveying the
509
+ covered work in a country, or your recipient's use of the covered work
510
+ in a country, would infringe one or more identifiable patents in that
511
+ country that you have reason to believe are valid.
512
+
513
+ If, pursuant to or in connection with a single transaction or
514
+ arrangement, you convey, or propagate by procuring conveyance of, a
515
+ covered work, and grant a patent license to some of the parties
516
+ receiving the covered work authorizing them to use, propagate, modify
517
+ or convey a specific copy of the covered work, then the patent license
518
+ you grant is automatically extended to all recipients of the covered
519
+ work and works based on it.
520
+
521
+ A patent license is "discriminatory" if it does not include within
522
+ the scope of its coverage, prohibits the exercise of, or is
523
+ conditioned on the non-exercise of one or more of the rights that are
524
+ specifically granted under this License. You may not convey a covered
525
+ work if you are a party to an arrangement with a third party that is
526
+ in the business of distributing software, under which you make payment
527
+ to the third party based on the extent of your activity of conveying
528
+ the work, and under which the third party grants, to any of the
529
+ parties who would receive the covered work from you, a discriminatory
530
+ patent license (a) in connection with copies of the covered work
531
+ conveyed by you (or copies made from those copies), or (b) primarily
532
+ for and in connection with specific products or compilations that
533
+ contain the covered work, unless you entered into that arrangement,
534
+ or that patent license was granted, prior to 28 March 2007.
535
+
536
+ Nothing in this License shall be construed as excluding or limiting
537
+ any implied license or other defenses to infringement that may
538
+ otherwise be available to you under applicable patent law.
539
+
540
+ 12. No Surrender of Others' Freedom.
541
+
542
+ If conditions are imposed on you (whether by court order, agreement or
543
+ otherwise) that contradict the conditions of this License, they do not
544
+ excuse you from the conditions of this License. If you cannot convey a
545
+ covered work so as to satisfy simultaneously your obligations under this
546
+ License and any other pertinent obligations, then as a consequence you may
547
+ not convey it at all. For example, if you agree to terms that obligate you
548
+ to collect a royalty for further conveying from those to whom you convey
549
+ the Program, the only way you could satisfy both those terms and this
550
+ License would be to refrain entirely from conveying the Program.
551
+
552
+ 13. Use with the GNU Affero General Public License.
553
+
554
+ Notwithstanding any other provision of this License, you have
555
+ permission to link or combine any covered work with a work licensed
556
+ under version 3 of the GNU Affero General Public License into a single
557
+ combined work, and to convey the resulting work. The terms of this
558
+ License will continue to apply to the part which is the covered work,
559
+ but the special requirements of the GNU Affero General Public License,
560
+ section 13, concerning interaction through a network will apply to the
561
+ combination as such.
562
+
563
+ 14. Revised Versions of this License.
564
+
565
+ The Free Software Foundation may publish revised and/or new versions of
566
+ the GNU General Public License from time to time. Such new versions will
567
+ be similar in spirit to the present version, but may differ in detail to
568
+ address new problems or concerns.
569
+
570
+ Each version is given a distinguishing version number. If the
571
+ Program specifies that a certain numbered version of the GNU General
572
+ Public License "or any later version" applies to it, you have the
573
+ option of following the terms and conditions either of that numbered
574
+ version or of any later version published by the Free Software
575
+ Foundation. If the Program does not specify a version number of the
576
+ GNU General Public License, you may choose any version ever published
577
+ by the Free Software Foundation.
578
+
579
+ If the Program specifies that a proxy can decide which future
580
+ versions of the GNU General Public License can be used, that proxy's
581
+ public statement of acceptance of a version permanently authorizes you
582
+ to choose that version for the Program.
583
+
584
+ Later license versions may give you additional or different
585
+ permissions. However, no additional obligations are imposed on any
586
+ author or copyright holder as a result of your choosing to follow a
587
+ later version.
588
+
589
+ 15. Disclaimer of Warranty.
590
+
591
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599
+
600
+ 16. Limitation of Liability.
601
+
602
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610
+ SUCH DAMAGES.
611
+
612
+ 17. Interpretation of Sections 15 and 16.
613
+
614
+ If the disclaimer of warranty and limitation of liability provided
615
+ above cannot be given local legal effect according to their terms,
616
+ reviewing courts shall apply local law that most closely approximates
617
+ an absolute waiver of all civil liability in connection with the
618
+ Program, unless a warranty or assumption of liability accompanies a
619
+ copy of the Program in return for a fee.
620
+
621
+ END OF TERMS AND CONDITIONS
622
+
623
+ How to Apply These Terms to Your New Programs
624
+
625
+ If you develop a new program, and you want it to be of the greatest
626
+ possible use to the public, the best way to achieve this is to make it
627
+ free software which everyone can redistribute and change under these terms.
628
+
629
+ To do so, attach the following notices to the program. It is safest
630
+ to attach them to the start of each source file to most effectively
631
+ state the exclusion of warranty; and each file should have at least
632
+ the "copyright" line and a pointer to where the full notice is found.
633
+
634
+ <one line to give the program's name and a brief idea of what it does.>
635
+ Copyright (C) <year> <name of author>
636
+
637
+ This program is free software: you can redistribute it and/or modify
638
+ it under the terms of the GNU General Public License as published by
639
+ the Free Software Foundation, either version 3 of the License, or
640
+ (at your option) any later version.
641
+
642
+ This program is distributed in the hope that it will be useful,
643
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
644
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645
+ GNU General Public License for more details.
646
+
647
+ You should have received a copy of the GNU General Public License
648
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
649
+
650
+ Also add information on how to contact you by electronic and paper mail.
651
+
652
+ If the program does terminal interaction, make it output a short
653
+ notice like this when it starts in an interactive mode:
654
+
655
+ <program> Copyright (C) <year> <name of author>
656
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657
+ This is free software, and you are welcome to redistribute it
658
+ under certain conditions; type `show c' for details.
659
+
660
+ The hypothetical commands `show w' and `show c' should show the appropriate
661
+ parts of the General Public License. Of course, your program's commands
662
+ might be different; for a GUI interface, you would use an "about box".
663
+
664
+ You should also get your employer (if you work as a programmer) or school,
665
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
666
+ For more information on this, and how to apply and follow the GNU GPL, see
667
+ <https://www.gnu.org/licenses/>.
668
+
669
+ The GNU General Public License does not permit incorporating your program
670
+ into proprietary programs. If your program is a subroutine library, you
671
+ may consider it more useful to permit linking proprietary applications with
672
+ the library. If this is what you want to do, use the GNU Lesser General
673
+ Public License instead of this License. But first, please read
674
+ <https://www.gnu.org/licenses/why-not-lgpl.html>.
custom_nodes/ComfyUI-ReActor/README.md ADDED
@@ -0,0 +1,488 @@
1
+ <div align="center">
2
+
3
+ <img src="https://github.com/Gourieff/Assets/raw/main/sd-webui-reactor/ReActor_logo_NEW_EN.png?raw=true" alt="logo" width="180px"/>
4
+
5
+ ![Version](https://img.shields.io/badge/node_version-0.6.0_alpha1-lightgreen?style=for-the-badge&labelColor=darkgreen)
6
+
7
+ <!--<sup>
8
+ <font color=brightred>
9
+
10
+ ## !!! [Important Update](#latestupdate) !!!<br>Don't forget to add the Node again in existing workflows
11
+
12
+ </font>
13
+ </sup>-->
14
+
15
+ <a href="https://boosty.to/artgourieff" target="_blank">
16
+ <img src="https://lovemet.ru/img/boosty.jpg" width="108" alt="Support Me on Boosty"/>
17
+ <br>
18
+ <sup>
19
+ Support This Project
20
+ </sup>
21
+ </a>
22
+
23
+ <hr>
24
+
25
+ [![Commit activity](https://img.shields.io/github/commit-activity/t/Gourieff/ComfyUI-ReActor/main?cacheSeconds=0)](https://github.com/Gourieff/ComfyUI-ReActor/commits/main)
26
+ ![Last commit](https://img.shields.io/github/last-commit/Gourieff/ComfyUI-ReActor/main?cacheSeconds=0)
27
+ [![Opened issues](https://img.shields.io/github/issues/Gourieff/ComfyUI-ReActor?color=red)](https://github.com/Gourieff/ComfyUI-ReActor/issues?cacheSeconds=0)
28
+ [![Closed issues](https://img.shields.io/github/issues-closed/Gourieff/ComfyUI-ReActor?color=green&cacheSeconds=0)](https://github.com/Gourieff/ComfyUI-ReActor/issues?q=is%3Aissue+state%3Aclosed)
29
+ ![License](https://img.shields.io/github/license/Gourieff/ComfyUI-ReActor)
30
+
31
+ English | [Русский](/README_RU.md)
32
+
33
+ # ReActor Nodes for ComfyUI<br><sub><sup>-=SFW-Friendly=-</sup></sub>
34
+
35
+ </div>
36
+
37
+ ### The Fast and Simple Face Swap Extension Nodes for ComfyUI, based on [blocked ReActor](https://github.com/Gourieff/comfyui-reactor-node) - now it has a nudity detector to avoid using this software with 18+ content
38
+
39
+ > By using this Node you accept and assume [responsibility](#disclaimer)
40
+
41
+ <div align="center">
42
+
43
+ ---
44
+ [**What's new**](#latestupdate) | [**Installation**](#installation) | [**Usage**](#usage) | [**Troubleshooting**](#troubleshooting) | [**Updating**](#updating) | [**Disclaimer**](#disclaimer) | [**Credits**](#credits) | [**Note!**](#note)
45
+
46
+ ---
47
+
48
+ </div>
49
+
50
+ <a name="latestupdate">
51
+
52
+ ## What's new in the latest update
53
+
54
+ ### 0.6.0 <sub><sup>ALPHA1</sup></sub>
55
+
56
+ - New Node `ReActorSetWeight` - you can now set the strength of face swap for `source_image` or `face_model` from 0% to 100% (in 12.5% step)
57
+
58
+ <center>
59
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.6.0-whatsnew-01.jpg?raw=true" alt="0.6.0-whatsnew-01" width="100%"/>
60
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.6.0-whatsnew-02.jpg?raw=true" alt="0.6.0-whatsnew-02" width="100%"/>
61
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.6.0-alpha1-01.gif?raw=true" alt="0.6.0-whatsnew-03" width="540px"/>
62
+ </center>
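To make the 12.5% step concrete, here is a small illustration (not part of the node's code) of the weight values the description above implies, plus a helper that snaps an arbitrary percentage to the nearest allowed step.

```python
# Illustration only: the weight steps implied by "0% to 100% in 12.5% steps".
ALLOWED_WEIGHTS = [i * 12.5 for i in range(9)]   # [0.0, 12.5, 25.0, ..., 100.0]

def snap_to_step(percent: float) -> float:
    """Round an arbitrary percentage to the nearest 12.5% step, clamped to 0-100."""
    return min(max(round(percent / 12.5) * 12.5, 0.0), 100.0)

print(ALLOWED_WEIGHTS)    # nine allowed values
print(snap_to_step(40))   # 37.5
```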
63
+
64
+ <details>
65
+ <summary><a>Previous versions</a></summary>
66
+
67
+ ### 0.5.2
68
+
69
+ - ReSwapper models support. Although Inswapper still has the best similarity, but ReSwapper is evolving - thanks @somanchiu https://github.com/somanchiu/ReSwapper for the ReSwapper models and the ReSwapper project! This is a good step for the Community in the Inswapper's alternative creation!
70
+
71
+ <center>
72
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.2-whatsnew-03.jpg?raw=true" alt="0.5.2-whatsnew-03" width="75%"/>
73
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.2-whatsnew-04.jpg?raw=true" alt="0.5.2-whatsnew-04" width="75%"/>
74
+ </center>
75
+
76
+ You can download ReSwapper models here:
77
+ https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models
78
+ Just put them into the "models/reswapper" directory.
79
+
80
+ - NSFW-detector to not violate [GitHub rules](https://docs.github.com/en/site-policy/acceptable-use-policies/github-misinformation-and-disinformation#synthetic--manipulated-media-tools)
81
+ - New node "Unload ReActor Models" - is useful for complex WFs when you need to free some VRAM utilized by ReActor
82
+
83
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.2-whatsnew-01.jpg?raw=true" alt="0.5.2-whatsnew-01" width="100%"/>
84
+
85
+ - Support of ORT CoreML and ROCM EPs, just install onnxruntime version you need
86
+ - Install script improvements to install latest versions of ORT-GPU
87
+
88
+ <center>
89
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.2-whatsnew-02.jpg?raw=true" alt="0.5.2-whatsnew-02" width="50%"/>
90
+ </center>
91
+
92
+ - Fixes and improvements
93
+
94
+
95
+ ### 0.5.1
96
+
97
+ - Support of GPEN 1024/2048 restoration models (available in the HF dataset https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models/facerestore_models)
98
+ - ReActorFaceBoost Node - an attempt to improve the quality of swapped faces. The idea is to restore and scale the swapped face (according to the `face_size` parameter of the restoration model) BEFORE pasting it to the target image (via inswapper algorithms), more information is [here (PR#321)](https://github.com/Gourieff/comfyui-reactor-node/pull/321)
99
+
100
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.1-whatsnew-01.jpg?raw=true" alt="0.5.1-whatsnew-01" width="100%"/>
101
+
102
+ [Full size demo preview](https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.1-whatsnew-02.png)
103
+
104
+ - Sorting facemodels alphabetically
105
+ - A lot of fixes and improvements
106
+
107
+ ### [0.5.0 <sub><sup>BETA4</sup></sub>](https://github.com/Gourieff/comfyui-reactor-node/releases/tag/v0.5.0)
108
+
109
+ - Spandrel lib support for GFPGAN
110
+
111
+ ### 0.5.0 <sub><sup>BETA3</sup></sub>
112
+
113
+ - Fixes: "RAM issue", "No detection" for MaskingHelper
114
+
115
+ ### 0.5.0 <sub><sup>BETA2</sup></sub>
116
+
117
+ - You can now build a blended face model from a batch of face models you already have, just add the "Make Face Model Batch" node to your workflow and connect several models via "Load Face Model"
118
+ - Huge performance boost of the image analyzer's module! 10x speed up! Working with videos is now a pleasure!
119
+
120
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.0-whatsnew-05.png?raw=true" alt="0.5.0-whatsnew-05" width="100%"/>
121
+
122
+ ### 0.5.0 <sub><sup>BETA1</sup></sub>
123
+
124
+ - SWAPPED_FACE output for the Masking Helper Node
125
+ - FIX: Empty A-channel for Masking Helper IMAGE output (causing errors with some nodes) was removed
126
+
127
+ ### 0.5.0 <sub><sup>ALPHA1</sup></sub>
128
+
129
+ - ReActorBuildFaceModel Node got "face_model" output to provide a blended face model directly to the main Node:
130
+
131
+ Basic workflow [💾](https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/workflows/ReActor--Build-Blended-Face-Model--v2.json)
132
+
133
+ - Face Masking feature is available now, just add the "ReActorMaskHelper" Node to the workflow and connect it as shown below:
134
+
135
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.0-whatsnew-01.jpg?raw=true" alt="0.5.0-whatsnew-01" width="100%"/>
136
+
137
+ If you don't have the "face_yolov8m.pt" Ultralytics model - you can download it from the [Assets](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/detection/bbox/face_yolov8m.pt) and put it into the "ComfyUI\models\ultralytics\bbox" directory
138
+ <br>
139
+ As well as ["sam_vit_b_01ec64.pth"](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/sams/sam_vit_b_01ec64.pth) model - download (if you don't have it) and put it into the "ComfyUI\models\sams" directory;
140
+
141
+ Use this Node to gain the best results of the face swapping process:
142
+
143
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.0-whatsnew-02.jpg?raw=true" alt="0.5.0-whatsnew-02" width="100%"/>
144
+
145
+ - ReActorImageDublicator Node - rather useful for those who create videos, it helps to duplicate one image to several frames to use them with VAE Encoder (e.g. live avatars):
146
+
147
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.0-whatsnew-03.jpg?raw=true" alt="0.5.0-whatsnew-03" width="100%"/>
148
+
149
+ - ReActorFaceSwapOpt (a simplified version of the Main Node) + ReActorOptions Nodes to set some additional options such as (new) "input/source faces separate order". Yes! You can now set the order of faces in the index in the way you want ("large to small" goes by default)!
150
+
151
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.0-whatsnew-04.jpg?raw=true" alt="0.5.0-whatsnew-04" width="100%"/>
152
+
153
+ - Little speed boost when analyzing target images (unfortunately it is still quite slow in compare to swapping and restoring...)
154
+
155
+ ### [0.4.2](https://github.com/Gourieff/comfyui-reactor-node/releases/tag/v0.4.2)
156
+
157
+ - GPEN-BFR-512 and RestoreFormer_Plus_Plus face restoration models support
158
+
159
+ You can download models here: https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models/facerestore_models
160
+ <br>Put them into the `ComfyUI\models\facerestore_models` folder
161
+
162
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.2-whatsnew-04.jpg?raw=true" alt="0.4.2-whatsnew-04" width="100%"/>
163
+
164
+ - Due to popular demand - you can now blend several images with persons into one face model file and use it with "Load Face Model" Node or in SD WebUI as well;
165
+
166
+ Experiment and create new faces or blend faces of one person to gain better accuracy and likeness!
167
+
168
+ Just add the ImpactPack's "Make Image Batch" Node as the input to the ReActor's one and load images you want to blend into one model:
169
+
170
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.2-whatsnew-01.jpg?raw=true" alt="0.4.2-whatsnew-01" width="100%"/>
171
+
172
+ Result example (the new face was created from 4 faces of different actresses):
173
+
174
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.2-whatsnew-02.jpg?raw=true" alt="0.4.2-whatsnew-02" width="75%"/>
175
+
176
+ Basic workflow [💾](https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/workflows/ReActor--Build-Blended-Face-Model--v1.json)
177
+
178
+ ### [0.4.1](https://github.com/Gourieff/comfyui-reactor-node/releases/tag/v0.4.1)
179
+
180
+ - CUDA 12 Support - don't forget to run (Windows) `install.bat` or (Linux/MacOS) `install.py` for ComfyUI's Python enclosure or try to install ORT-GPU for CU12 manually (https://onnxruntime.ai/docs/install/#install-onnx-runtime-gpu-cuda-12x)
181
+ - Issue https://github.com/Gourieff/comfyui-reactor-node/issues/173 fix
182
+
183
+ - Separate Node for the Face Restoration postprocessing (FR https://github.com/Gourieff/comfyui-reactor-node/issues/191), can be found inside ReActor's menu (RestoreFace Node)
184
+ - (Windows) Installation can be done for Python from the System's PATH
185
+ - Different fixes and improvements
186
+
187
+ - Face Restore Visibility and CodeFormer Weight (Fidelity) options are now available! Don't forget to reload the Node in your existing workflow
188
+
189
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.1-whatsnew-01.jpg?raw=true" alt="0.4.1-whatsnew-01" width="100%"/>
190
+
191
+ ### [0.4.0](https://github.com/Gourieff/comfyui-reactor-node/releases/tag/v0.4.0)
192
+
193
+ - Input "input_image" goes first now, it gives a correct bypass and also it is right to have the main input first;
194
+ - You can now save face models as "safetensors" files (`ComfyUI\models\reactor\faces`) and load them into ReActor to implement different scenarios and keep super-lightweight face models of the faces you use:
195
+
196
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.0-whatsnew-01.jpg?raw=true" alt="0.4.0-whatsnew-01" width="100%"/>
197
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.0-whatsnew-02.jpg?raw=true" alt="0.4.0-whatsnew-02" width="100%"/>
198
+
199
+ - Ability to build and save face models directly from an image:
200
+
201
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.0-whatsnew-03.jpg?raw=true" alt="0.4.0-whatsnew-03" width="50%"/>
202
+
203
+ - Both inputs are optional - just connect one of them according to your workflow; if both are connected, `image` takes priority.
204
+ - Different fixes making this extension better.
205
+
206
+ Thanks to everyone who finds bugs, suggests new features and supports this project!
207
+
208
+ </details>
209
+
210
+ ## Installation
211
+
212
+ <details>
213
+ <summary>SD WebUI: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/">AUTOMATIC1111</a> or <a href="https://github.com/vladmandic/automatic">SD.Next</a></summary>
214
+
215
+ 1. Close (stop) your SD-WebUI/Comfy Server if it's running
216
+ 2. (For Windows Users):
217
+ - Install [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/) (Community version - you need this step to build Insightface)
218
+ - OR only [VS C++ Build Tools](https://visualstudio.microsoft.com/visual-cpp-build-tools/) and select "Desktop Development with C++" under "Workloads -> Desktop & Mobile"
219
+ - OR if you don't want to install VS or VS C++ BT - follow [these steps (sec. I)](#insightfacebuild)
220
+ 3. Go to the `extensions\sd-webui-comfyui\ComfyUI\custom_nodes`
221
+ 4. Open Console or Terminal and run `git clone https://github.com/Gourieff/ComfyUI-ReActor`
222
+ 5. Go to the SD WebUI root folder, open Console or Terminal and run (Windows users) `.\venv\Scripts\activate` or (Linux/MacOS) `source venv/bin/activate`
223
+ 6. `python -m pip install -U pip`
224
+ 7. `cd extensions\sd-webui-comfyui\ComfyUI\custom_nodes\ComfyUI-ReActor`
225
+ 8. `python install.py`
226
+ 9. Please wait until the installation process is finished
227
+ 10. (Since version 0.3.0) Download additional face restoration models from the link below and put them into the `extensions\sd-webui-comfyui\ComfyUI\models\facerestore_models` directory:<br>
228
+ https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models/facerestore_models
229
+ 11. Run SD WebUI and check console for the message that ReActor Node is running:
230
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/uploads/console_status_running.jpg?raw=true" alt="console_status_running" width="759"/>
231
+
232
+ 12. Go to the ComfyUI tab and find the ReActor Node inside the `ReActor` menu or by using the search:
233
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/uploads/webui-demo.png?raw=true" alt="webui-demo" width="100%"/>
234
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/uploads/search-demo.png?raw=true" alt="webui-demo" width="1043"/>
235
+
236
+ </details>
237
+
238
+ <details>
239
+ <summary>Standalone (Portable) <a href="https://github.com/comfyanonymous/ComfyUI">ComfyUI</a> for Windows</summary>
240
+
241
+ 1. Do the following:
242
+ - Install [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/) (Community version - you need this step to build Insightface)
243
+ - OR only [VS C++ Build Tools](https://visualstudio.microsoft.com/visual-cpp-build-tools/) and select "Desktop Development with C++" under "Workloads -> Desktop & Mobile"
244
+ - OR if you don't want to install VS or VS C++ BT - follow [these steps (sec. I)](#insightfacebuild)
245
+ 2. Choose between two options:
246
+ - (ComfyUI Manager) Open ComfyUI Manager, click "Install Custom Nodes", type "ReActor" in the "Search" field and then click "Install". After ComfyUI completes the process, please restart the Server.
247
+ - (Manually) Go to `ComfyUI\custom_nodes`, open Console and run `git clone https://github.com/Gourieff/ComfyUI-ReActor`
248
+ 3. Go to `ComfyUI\custom_nodes\ComfyUI-ReActor` and run `install.bat`
249
+ 4. If you don't have the "face_yolov8m.pt" Ultralytics model - you can download it from the [Assets](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/detection/bbox/face_yolov8m.pt) and put it into the "ComfyUI\models\ultralytics\bbox" directory<br>Do the same with one or both of the "Sams" models from [here](https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models/sams) - download them (if you don't have them) and put them into the "ComfyUI\models\sams" directory
250
+ 5. Run ComfyUI and find the ReActor Nodes inside the `ReActor` menu or by using the search
251
+
252
+ </details>
253
+
254
+ ## Usage
255
+
256
+ You can find ReActor Nodes inside the menu `ReActor` or by using a search (just type "ReActor" in the search field)
257
+
258
+ List of Nodes:
259
+ - ••• Main Nodes •••
260
+ - ReActorFaceSwap (Main Node)
261
+ - ReActorFaceSwapOpt (Main Node with the additional Options input)
262
+ - ReActorOptions (Options for ReActorFaceSwapOpt)
263
+ - ReActorFaceBoost (Face Booster Node)
264
+ - ReActorMaskHelper (Masking Helper)
265
+ - ••• Operations with Face Models •••
266
+ - ReActorSaveFaceModel (Save Face Model)
267
+ - ReActorLoadFaceModel (Load Face Model)
268
+ - ReActorBuildFaceModel (Build Blended Face Model)
269
+ - ReActorMakeFaceModelBatch (Make Face Model Batch)
270
+ - ••• Additional Nodes •••
271
+ - ReActorRestoreFace (Face Restoration)
272
+ - ReActorImageDublicator (Dublicate one Image to Images List)
273
+ - ImageRGBA2RGB (Convert RGBA to RGB)
274
+
275
+ Connect all required slots and run the queue.
276
+
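+ If you prefer to queue a workflow programmatically instead of pressing "Queue Prompt", here is a minimal sketch (not part of ReActor) that posts a workflow exported via "Save (API Format)" to a local ComfyUI server; the file name and the default port 8188 are assumptions:
+
+ ```python
+ # A minimal sketch: queue a saved ComfyUI workflow (API format) that contains
+ # a ReActorFaceSwap node. Assumes a local server on the default port 8188 and
+ # an exported file named "reactor_faceswap_api.json" (both are examples).
+ import json
+ import urllib.request
+
+ def queue_workflow(path: str, host: str = "127.0.0.1", port: int = 8188) -> dict:
+     with open(path, "r", encoding="utf-8") as f:
+         workflow = json.load(f)  # the node graph, including the ReActor node(s)
+     payload = json.dumps({"prompt": workflow}).encode("utf-8")
+     req = urllib.request.Request(
+         f"http://{host}:{port}/prompt",
+         data=payload,
+         headers={"Content-Type": "application/json"},
+     )
+     with urllib.request.urlopen(req) as resp:
+         return json.load(resp)  # the response contains the id of the queued prompt
+
+ # print(queue_workflow("reactor_faceswap_api.json"))
+ ```
+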
277
+ ### Main Node Inputs
278
+
279
+ - `input_image` - the image to be processed (the target image, analogous to "target image" in the SD WebUI extension);
280
+ - Supported Nodes: "Load Image", "Load Video" or any other nodes providing images as an output;
281
+ - `source_image` - an image with a face or faces to swap into the `input_image` (the source image, analogous to "source image" in the SD WebUI extension);
282
+ - Supported Nodes: "Load Image" or any other nodes providing images as an output;
283
+ - `face_model` - the input for the "Load Face Model" Node or another ReActor node that provides a face model file (face embedding) you created earlier via the "Save Face Model" Node;
284
+ - Supported Nodes: "Load Face Model", "Build Blended Face Model";
285
+
286
+ ### Main Node Outputs
287
+
288
+ - `IMAGE` - the output with the resulting image;
289
+ - Supported Nodes: any nodes which have images as an input;
290
+ - `FACE_MODEL` - an output providing the source face's model built during the swapping process;
291
+ - Supported Nodes: "Save Face Model", "ReActor", "Make Face Model Batch";
292
+
293
+ ### Face Restoration
294
+
295
+ Since version 0.3.0 the ReActor Node has built-in face restoration.<br>Just download the models you want (see the [Installation](#installation) instructions) and select one of them to restore the resulting face(s) during the face swap. It will enhance face details and make your result more accurate.
296
+
297
+ ### Face Indexes
298
+
299
+ By default ReActor detects faces in images from "large" to "small".<br>You can change this by adding the ReActorFaceSwapOpt node together with ReActorOptions.
300
+
301
+ If you need to specify particular faces, you can set indexes for the source and input images.
302
+
303
+ Index of the first detected face is 0.
304
+
305
+ You can set indexes in the order you need.<br>
306
+ E.g.: 0,1,2 (for Source); 1,0,2 (for Input).<br>This means: the second Input face (index = 1) will be swapped with the first Source face (index = 0), and so on.
307
+
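+ For illustration only (this is not ReActor code), the example mapping above can be read like this:
+
+ ```python
+ # Illustrative sketch of the example above: source "0,1,2" mapped onto input "1,0,2".
+ source_faces_index = "0,1,2"
+ input_faces_index = "1,0,2"
+
+ for input_idx, source_idx in zip(input_faces_index.split(","), source_faces_index.split(",")):
+     print(f"Input face #{input_idx} gets Source face #{source_idx}")
+
+ # Input face #1 gets Source face #0
+ # Input face #0 gets Source face #1
+ # Input face #2 gets Source face #2
+ ```
+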
308
+ ### Genders
309
+
310
+ You can specify the gender to detect in images.<br>
311
+ ReActor will swap a face only if it meets the given condition.
312
+
313
+ ### Face Models
314
+
315
+ Since version 0.4.0 you can save face models as "safetensors" files (stored in `ComfyUI\models\reactor\faces`) and load them into ReActor to implement different scenarios and keep super-lightweight face models of the faces you use.
316
+
317
+ To make new models appear in the list of the "Load Face Model" Node - just refresh the page of your ComfyUI web application.<br>
318
+ (I recommend using ComfyUI Manager - otherwise your workflow can be lost after you refresh the page if you didn't save it beforehand.)
319
+
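+ If you want to peek inside a saved face model from Python, a small sketch using the `safetensors` package could look like this (the file name is just an example):
+
+ ```python
+ # A minimal sketch, assuming the `safetensors` package is installed and a model
+ # was saved earlier via the "Save Face Model" Node (the file name is an example).
+ from safetensors import safe_open
+
+ path = r"ComfyUI\models\reactor\faces\my_face.safetensors"
+
+ with safe_open(path, framework="np") as f:
+     for key in f.keys():  # names of the tensors stored in the face model file
+         print(key, f.get_tensor(key).shape)
+ ```
+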
320
+ ## Troubleshooting
321
+
322
+ <a name="insightfacebuild">
323
+
324
+ ### **I. (For Windows users) If you still cannot build Insightface for some reason or just don't want to install Visual Studio or VS C++ Build Tools - do the following:**
325
+
326
+ 1. (ComfyUI Portable) From the root folder check the version of Python:<br>run CMD and type `python_embeded\python.exe -V`<br>You should see 3.10, 3.11 or 3.12
327
+ 2. Download the prebuilt Insightface package [for Python 3.10](https://github.com/Gourieff/Assets/raw/main/Insightface/insightface-0.7.3-cp310-cp310-win_amd64.whl) or [for Python 3.11](https://github.com/Gourieff/Assets/raw/main/Insightface/insightface-0.7.3-cp311-cp311-win_amd64.whl) (if in the previous step you saw 3.11) or [for Python 3.12](https://github.com/Gourieff/Assets/raw/main/Insightface/insightface-0.7.3-cp312-cp312-win_amd64.whl) (if in the previous step you saw 3.12) and put it into the stable-diffusion-webui (A1111 or SD.Next) root folder (where the "webui-user.bat" file is) or into the ComfyUI root folder if you use ComfyUI Portable
328
+ 3. From the root folder run:
329
+ - (SD WebUI) CMD and `.\venv\Scripts\activate`
330
+ - (ComfyUI Portable) run CMD
331
+ 4. Then update your PIP:
332
+ - (SD WebUI) `python -m pip install -U pip`
333
+ - (ComfyUI Portable) `python_embeded\python.exe -m pip install -U pip`
334
+ 5. Then install Insightface:
335
+ - (SD WebUI) `pip install insightface-0.7.3-cp310-cp310-win_amd64.whl` (for 3.10) or `pip install insightface-0.7.3-cp311-cp311-win_amd64.whl` (for 3.11) or `pip install insightface-0.7.3-cp312-cp312-win_amd64.whl` (for 3.12)
336
+ - (ComfyUI Portable) `python_embeded\python.exe -m pip install insightface-0.7.3-cp310-cp310-win_amd64.whl` (for 3.10) or `python_embeded\python.exe -m pip install insightface-0.7.3-cp311-cp311-win_amd64.whl` (for 3.11) or `python_embeded\python.exe -m pip install insightface-0.7.3-cp312-cp312-win_amd64.whl` (for 3.12)
337
+ 6. Enjoy!
338
+
339
+ ### **II. "AttributeError: 'NoneType' object has no attribute 'get'"**
340
+
341
+ This error may occur if there's something wrong with the model file `inswapper_128.onnx`
342
+
343
+ Try to download it manually from [here](https://github.com/facefusion/facefusion-assets/releases/download/models/inswapper_128.onnx)
344
+ and put it into the `ComfyUI\models\insightface` folder, replacing the existing one
345
+
346
+ ### **III. "reactor.execute() got an unexpected keyword argument 'reference_image'"**
347
+
348
+ This means that input points have been changed with the latest update<br>
349
+ Remove the current ReActor Node from your workflow and add it again
350
+
351
+ ### **IV. ControlNet Aux Node IMPORT failed error when using with ReActor Node**
352
+
353
+ 1. Close ComfyUI if it is running
354
+ 2. Go to the ComfyUI root folder, open CMD there and run:
355
+ - `python_embeded\python.exe -m pip uninstall -y opencv-python opencv-contrib-python opencv-python-headless`
356
+ - `python_embeded\python.exe -m pip install opencv-python==4.7.0.72`
357
+ 3. That's it!
358
+
359
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/uploads/reactor-w-controlnet.png?raw=true" alt="reactor+controlnet" />
360
+
361
+ ### **V. "ModuleNotFoundError: No module named 'basicsr'" or "subprocess-exited-with-error" during future-0.18.3 installation**
362
+
363
+ - Download https://github.com/Gourieff/Assets/raw/main/comfyui-reactor-node/future-0.18.3-py3-none-any.whl<br>
364
+ - Put it into the ComfyUI root folder and run:
365
+
366
+ python_embeded\python.exe -m pip install future-0.18.3-py3-none-any.whl
367
+
368
+ - Then:
369
+
370
+ python_embeded\python.exe -m pip install basicsr
371
+
372
+ ### **VI. "fatal: fetch-pack: invalid index-pack output" when you try to `git clone` the repository"**
373
+
374
+ Try to clone with `--depth=1` (last commit only):
375
+
376
+ git clone --depth=1 https://github.com/Gourieff/ComfyUI-ReActor
377
+
378
+ Then retrieve the rest (if you need):
379
+
380
+ git fetch --unshallow
381
+
382
+ ## Updating
383
+
384
+ Just put the .bat or .sh script from this [Repo](https://github.com/Gourieff/sd-webui-extensions-updater) into the `ComfyUI\custom_nodes` directory and run it when you need to check for updates
385
+
386
+ ### Disclaimer
387
+
388
+ This software is meant to be a productive contribution to the rapidly growing AI-generated media industry. It will help artists with tasks such as animating a custom character or using the character as a model for clothing etc.
389
+
390
+ The developers of this software are aware of its possible unethical applications and are committed to taking preventative measures against them. We will continue to develop this project in a positive direction while adhering to law and ethics.
391
+
392
+ Users of this software are expected to use it responsibly while abiding by local law. If the face of a real person is used, users are advised to obtain consent from the person concerned and to clearly state that the content is a deepfake when posting it online. **Developers and Contributors of this software are not responsible for the actions of end-users.**
393
+
394
+ By using this extension you agree not to create any content that:
395
+ - violates any laws;
396
+ - causes any harm to a person or persons;
397
+ - propagates (spreads) any information (whether public or personal) or images (whether public or personal) which could be meant to cause harm;
398
+ - spreads misinformation;
399
+ - targets vulnerable groups of people.
400
+
401
+ This software utilizes the pre-trained models `buffalo_l` and `inswapper_128.onnx`, which are provided by [InsightFace](https://github.com/deepinsight/insightface/). These models are included under the following conditions:
402
+
403
+ [From the InsightFace license](https://github.com/deepinsight/insightface/tree/master/python-package): InsightFace's pre-trained models are available for non-commercial research purposes only. This includes both the auto-downloaded models and the manually downloaded models.
404
+
405
+ Users of this software must strictly adhere to these conditions of use. The developers and maintainers of this software are not responsible for any misuse of InsightFace’s pre-trained models.
406
+
407
+ Please note that if you intend to use this software for any commercial purposes, you will need to train your own models or find models that can be used commercially.
408
+
409
+ ### Models Hashsum
410
+
411
+ #### Safe-to-use models have the following hashes:
412
+
413
+ inswapper_128.onnx
414
+ ```
415
+ MD5:a3a155b90354160350efd66fed6b3d80
416
+ SHA256:e4a3f08c753cb72d04e10aa0f7dbe3deebbf39567d4ead6dce08e98aa49e16af
417
+ ```
418
+
419
+ 1k3d68.onnx
420
+
421
+ ```
422
+ MD5:6fb94fcdb0055e3638bf9158e6a108f4
423
+ SHA256:df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc
424
+ ```
425
+
426
+ 2d106det.onnx
427
+
428
+ ```
429
+ MD5:a3613ef9eb3662b4ef88eb90db1fcf26
430
+ SHA256:f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf
431
+ ```
432
+
433
+ det_10g.onnx
434
+
435
+ ```
436
+ MD5:4c10eef5c9e168357a16fdd580fa8371
437
+ SHA256:5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91
438
+ ```
439
+
440
+ genderage.onnx
441
+
442
+ ```
443
+ MD5:81c77ba87ab38163b0dec6b26f8e2af2
444
+ SHA256:4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb
445
+ ```
446
+
447
+ w600k_r50.onnx
448
+
449
+ ```
450
+ MD5:80248d427976241cbd1343889ed132b3
451
+ SHA256:4c06341c33c2ca1f86781dab0e829f88ad5b64be9fba56e56bc9ebdefc619e43
452
+ ```
453
+
454
+ **Please check hashsums if you download these models from unverified (or untrusted) sources**
455
+
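+ For example, you can compute both checksums with Python's standard `hashlib` and compare them with the values above (the path below is an example):
+
+ ```python
+ # Compute MD5 and SHA256 of a downloaded model file.
+ import hashlib
+
+ def file_hashes(path: str, chunk_size: int = 1 << 20):
+     md5, sha256 = hashlib.md5(), hashlib.sha256()
+     with open(path, "rb") as f:
+         for chunk in iter(lambda: f.read(chunk_size), b""):
+             md5.update(chunk)
+             sha256.update(chunk)
+     return md5.hexdigest(), sha256.hexdigest()
+
+ md5_hex, sha256_hex = file_hashes(r"ComfyUI\models\insightface\inswapper_128.onnx")
+ print("MD5:", md5_hex)
+ print("SHA256:", sha256_hex)
+ ```
+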
456
+ <a name="credits">
457
+
458
+ ## Thanks and Credits
459
+
460
+ <details>
461
+ <summary><a>Click to expand</a></summary>
462
+
463
+ <br>
464
+
465
+ |file|source|license|
466
+ |----|------|-------|
467
+ |[buffalo_l.zip](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/buffalo_l.zip) | [DeepInsight](https://github.com/deepinsight/insightface) | ![license](https://img.shields.io/badge/license-non_commercial-red) |
468
+ | [codeformer-v0.1.0.pth](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/facerestore_models/codeformer-v0.1.0.pth) | [sczhou](https://github.com/sczhou/CodeFormer) | ![license](https://img.shields.io/badge/license-non_commercial-red) |
469
+ | [GFPGANv1.3.pth](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/facerestore_models/GFPGANv1.3.pth) | [TencentARC](https://github.com/TencentARC/GFPGAN) | ![license](https://img.shields.io/badge/license-Apache_2.0-green.svg) |
470
+ | [GFPGANv1.4.pth](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/facerestore_models/GFPGANv1.4.pth) | [TencentARC](https://github.com/TencentARC/GFPGAN) | ![license](https://img.shields.io/badge/license-Apache_2.0-green.svg) |
471
+ | [inswapper_128.onnx](https://github.com/facefusion/facefusion-assets/releases/download/models/inswapper_128.onnx) | [DeepInsight](https://github.com/deepinsight/insightface) | ![license](https://img.shields.io/badge/license-non_commercial-red) |
472
+ | [inswapper_128_fp16.onnx](https://github.com/facefusion/facefusion-assets/releases/download/models/inswapper_128_fp16.onnx) | [Hillobar](https://github.com/Hillobar/Rope) | ![license](https://img.shields.io/badge/license-non_commercial-red) |
473
+
474
+ [BasicSR](https://github.com/XPixelGroup/BasicSR) - [@XPixelGroup](https://github.com/XPixelGroup) <br>
475
+ [facexlib](https://github.com/xinntao/facexlib) - [@xinntao](https://github.com/xinntao) <br>
476
+
477
+ [@s0md3v](https://github.com/s0md3v), [@henryruhs](https://github.com/henryruhs) - the original Roop App <br>
478
+ [@ssitu](https://github.com/ssitu) - the first version of [ComfyUI_roop](https://github.com/ssitu/ComfyUI_roop) extension
479
+
480
+ </details>
481
+
482
+ <a name="note">
483
+
484
+ ### Note!
485
+
486
+ **If you encounter any errors when you use the ReActor Node - don't rush to open an issue; first try to remove the current ReActor node from your workflow and add it again**
487
+
488
+ **The ReActor Node gets updates from time to time; new functions appear, and an old node may work with errors or not work at all**
custom_nodes/ComfyUI-ReActor/README_RU.md ADDED
@@ -0,0 +1,497 @@
1
+ <div align="center">
2
+
3
+ <img src="https://github.com/Gourieff/Assets/raw/main/sd-webui-reactor/ReActor_logo_NEW_RU.png?raw=true" alt="logo" width="180px"/>
4
+
5
+ ![Version](https://img.shields.io/badge/версия_нода-0.6.0_alpha1-lightgreen?style=for-the-badge&labelColor=darkgreen)
6
+
7
+ <!--<sup>
8
+ <font color=brightred>
9
+
10
+ ## !!! [Важные изменения](#latestupdate) !!!<br>Не забудьте добавить Нод заново в существующие воркфлоу
11
+
12
+ </font>
13
+ </sup>-->
14
+
15
+ <a href="https://boosty.to/artgourieff" target="_blank">
16
+ <img src="https://lovemet.ru/img/boosty.jpg" width="108" alt="Поддержать проект на Boosty"/>
17
+ <br>
18
+ <sup>
19
+ Поддержать проект
20
+ </sup>
21
+ </a>
22
+
23
+ <hr>
24
+
25
+ [![Commit activity](https://img.shields.io/github/commit-activity/t/Gourieff/ComfyUI-ReActor/main?cacheSeconds=0)](https://github.com/Gourieff/ComfyUI-ReActor/commits/main)
26
+ ![Last commit](https://img.shields.io/github/last-commit/Gourieff/ComfyUI-ReActor/main?cacheSeconds=0)
27
+ [![Opened issues](https://img.shields.io/github/issues/Gourieff/ComfyUI-ReActor?color=red)](https://github.com/Gourieff/ComfyUI-ReActor/issues?cacheSeconds=0)
28
+ [![Closed issues](https://img.shields.io/github/issues-closed/Gourieff/ComfyUI-ReActor?color=green&cacheSeconds=0)](https://github.com/Gourieff/ComfyUI-ReActor/issues?q=is%3Aissue+state%3Aclosed)
29
+ ![License](https://img.shields.io/github/license/Gourieff/ComfyUI-ReActor)
30
+
31
+ [English](/README.md) | Русский
32
+
33
+ # ReActor Nodes для ComfyUI<br><sub><sup>-=Безопасно для работы | SFW-Friendly=-</sup></sub>
34
+
35
+ </div>
36
+
37
+ ### Ноды (nodes) для быстрой и простой замены лиц на любых изображениях для работы с ComfyUI, основан на [ранее заблокированном РеАкторе](https://github.com/Gourieff/comfyui-reactor-node) - теперь имеется встроенный NSFW-детектор, исключающий замену лиц на изображениях с контентом 18+
38
+
39
+ > Используя данное ПО, вы понимаете и принимаете [ответственность](#disclaimer)
40
+
41
+ <div align="center">
42
+
43
+ ---
44
+ [**Что нового**](#latestupdate) | [**Установка**](#installation) | [**Использование**](#usage) | [**Устранение проблем**](#troubleshooting) | [**Обновление**](#updating) | [**Ответственность**](#disclaimer) | [**Благодарности**](#credits) | [**Заметка**](#note)
45
+
46
+ ---
47
+
48
+ </div>
49
+
50
+ <a name="latestupdate">
51
+
52
+ ## Что нового в последнем обновлении
53
+
54
+ ### 0.6.0 <sub><sup>ALPHA1</sup></sub>
55
+
56
+ - Новый нод `ReActorSetWeight` - теперь можно установить силу замены лица для `source_image` или `face_model` от 0% до 100% (с шагом 12.5%)
57
+
58
+ <center>
59
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.6.0-whatsnew-01.jpg?raw=true" alt="0.6.0-whatsnew-01" width="100%"/>
60
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.6.0-whatsnew-02.jpg?raw=true" alt="0.6.0-whatsnew-02" width="100%"/>
61
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.6.0-alpha1-01.gif?raw=true" alt="0.6.0-whatsnew-03" width="540px"/>
62
+ </center>
63
+
64
+ <details>
65
+ <summary><a>Предыдущие версии</a></summary>
66
+
67
+ ### 0.5.2
68
+
69
+ - Поддержка моделей ReSwapper. Несмотря на то, что Inswapper по-прежнему даёт лучшее сходство, но ReSwapper развивается - спасибо @somanchiu https://github.com/somanchiu/ReSwapper за эти модели и проект ReSwapper! Это хороший шаг для Сообщества в создании альтернативы Инсваппера!
70
+
71
+ <center>
72
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.2-whatsnew-03.jpg?raw=true" alt="0.5.2-whatsnew-03" width="75%"/>
73
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.2-whatsnew-04.jpg?raw=true" alt="0.5.2-whatsnew-04" width="75%"/>
74
+ </center>
75
+
76
+ Скачать модели ReSwapper можно отсюда:
77
+ https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models
78
+ Сохраните их в директорию "models/reswapper".
79
+
80
+ - NSFW-детектор, чтобы не нарушать [правила GitHub](https://docs.github.com/en/site-policy/acceptable-use-policies/github-misinformation-and-disinformation#synthetic--manipulated-media-tools)
81
+ - Новый нод "Unload ReActor Models" - полезен для сложных воркфлоу, когда вам нужно освободить ОЗУ, занятую РеАктором
82
+
83
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.2-whatsnew-01.jpg?raw=true" alt="0.5.2-whatsnew-01" width="100%"/>
84
+
85
+ - Поддержка ORT CoreML and ROCM EPs, достаточно установить ту версию onnxruntime, которая соответствует вашему GPU
86
+ - Некоторые улучшения скрипта установки для поддержки последней версии ORT-GPU
87
+
88
+ <center>
89
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.2-whatsnew-02.jpg?raw=true" alt="0.5.2-whatsnew-02" width="50%"/>
90
+ </center>
91
+
92
+ - Исправления и улучшения
93
+
94
+ ### 0.5.1
95
+
96
+ - Поддержка моделей восстановления лиц GPEN 1024/2048 (доступны в датасете на HF https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models/facerestore_models)
97
+ - Нод ReActorFaceBoost - попытка улучшить качество заменённых лиц. Идея состоит в том, чтобы восстановить и увеличить заменённое лицо (в соответствии с параметром `face_size` модели реставрации) ДО того, как лицо будет вставлено в целевое изображения (через алгоритмы инсваппера), больше информации [здесь (PR#321)](https://github.com/Gourieff/comfyui-reactor-node/pull/321)
98
+
99
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.1-whatsnew-01.jpg?raw=true" alt="0.5.1-whatsnew-01" width="100%"/>
100
+
101
+ [Полноразмерное демо-превью](https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.1-whatsnew-02.png)
102
+
103
+ - Сортировка моделей лиц по алфавиту
104
+ - Множество исправлений и улучшений
105
+
106
+ ### [0.5.0 <sub><sup>BETA4</sup></sub>](https://github.com/Gourieff/comfyui-reactor-node/releases/tag/v0.5.0)
107
+
108
+ - Поддержка библиотеки Spandrel при работе с GFPGAN
109
+
110
+ ### 0.5.0 <sub><sup>BETA3</sup></sub>
111
+
112
+ - Исправления: "RAM issue", "No detection" для MaskingHelper
113
+
114
+ ### 0.5.0 <sub><sup>BETA2</sup></sub>
115
+
116
+ - Появилась возможность строить смешанные модели лиц из пачки уже имеющихся моделей - добавьте для этого нод "Make Face Model Batch" в свой воркфлоу и загрузите несколько моделей через ноды "Load Face Model"
117
+ - Огромный буст производительности модуля анализа изображений! 10-кратный прирост скорости! Работа с видео теперь в удовольствие!
118
+
119
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.0-whatsnew-05.png?raw=true" alt="0.5.0-whatsnew-05" width="100%"/>
120
+
121
+ ### 0.5.0 <sub><sup>BETA1</sup></sub>
122
+
123
+ - Добавлен выход SWAPPED_FACE для нода Masking Helper
124
+ - FIX: Удалён пустой A-канал на выходе IMAGE нода Masking Helper (вызывавший ошибки с некоторым нодами)
125
+
126
+ ### 0.5.0 <sub><sup>ALPHA1</sup></sub>
127
+
128
+ - Нод ReActorBuildFaceModel получил выход "face_model" для отправки совмещенной модели лиц непосредственно в основной Нод:
129
+
130
+ Basic workflow [💾](https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/workflows/ReActor--Build-Blended-Face-Model--v2.json)
131
+
132
+ - Функции маски лица теперь доступна и в версии для Комфи, просто добавьте нод "ReActorMaskHelper" в воркфлоу и соедините узлы, как показано ниже:
133
+
134
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.0-whatsnew-01.jpg?raw=true" alt="0.5.0-whatsnew-01" width="100%"/>
135
+
136
+ Если модель "face_yolov8m.pt" у вас отсутствует - можете скачать её [отсюда](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/detection/bbox/face_yolov8m.pt) и положить в папку "ComfyUI\models\ultralytics\bbox"
137
+ <br>
138
+ То же самое и с ["sam_vit_b_01ec64.pth"](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/sams/sam_vit_b_01ec64.pth) - скачайте (если отсутствует) и положите в папку "ComfyUI\models\sams";
139
+
140
+ Данный нод поможет вам получить куда более аккуратный результат при замене лиц:
141
+
142
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.0-whatsnew-02.jpg?raw=true" alt="0.5.0-whatsnew-02" width="100%"/>
143
+
144
+ - Нод ReActorImageDublicator - полезен тем, кто создает видео, помогает продублировать одиночное изображение в несколько копий, чтобы использовать их, к примеру, с VAE энкодером:
145
+
146
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.0-whatsnew-03.jpg?raw=true" alt="0.5.0-whatsnew-03" width="100%"/>
147
+
148
+ - ReActorFaceSwapOpt (упрощенная версия основного нода) + нод ReActorOptions для установки дополнительных опций, как (новые) "отдельный порядок лиц для input/source". Да! Теперь можно установить любой порядок "чтения" индекса лиц на изображении, в т.ч. от большего к меньшему (по умолчанию)!
149
+
150
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.5.0-whatsnew-04.jpg?raw=true" alt="0.5.0-whatsnew-04" width="100%"/>
151
+
152
+ - Небольшое улучшение скорости анализа целевых изображений (input)
153
+
154
+ ### [0.4.2](https://github.com/Gourieff/comfyui-reactor-node/releases/tag/v0.4.2)
155
+
156
+ - Добавлена поддержка GPEN-BFR-512 и RestoreFormer_Plus_Plus моделей восстановления лиц
157
+
158
+ Скачать можно здесь: https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models/facerestore_models
159
+ <br>Добавьте модели в папку `ComfyUI\models\facerestore_models`
160
+
161
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.2-whatsnew-04.jpg?raw=true" alt="0.4.2-whatsnew-04" width="100%"/>
162
+
163
+ - По многочисленным просьбам появилась возможность строить смешанные модели лиц и в ComfyUI тоже и использовать их с нодом "Load Face Model" Node или в SD WebUI;
164
+
165
+ Экспериментируйте и создавайте новые лица или совмещайте разные лица нужного вам персонажа, чтобы добиться лучшей точности и схожести с оригиналом!
166
+
167
+ Достаточно добавить нод "Make Image Batch" (ImpactPack) на вход нового нода РеАктора и загрузить в пачку необходимые вам изображения для построения смешанной модели:
168
+
169
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.2-whatsnew-01.jpg?raw=true" alt="0.4.2-whatsnew-01" width="100%"/>
170
+
171
+ Пример результата (на основе лиц 4-х актрис создано новое лицо):
172
+
173
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.2-whatsnew-02.jpg?raw=true" alt="0.4.2-whatsnew-02" width="75%"/>
174
+
175
+ Базовый воркфлоу [💾](https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/workflows/ReActor--Build-Blended-Face-Model--v1.json)
176
+
177
+ ### [0.4.1](https://github.com/Gourieff/comfyui-reactor-node/releases/tag/v0.4.1)
178
+
179
+ - Поддержка CUDA 12 - не забудьте запустить (Windows) `install.bat` или (Linux/MacOS) `install.py` для используемого Python окружения или попробуйте установить ORT-GPU для CU12 вручную (https://onnxruntime.ai/docs/install/#install-onnx-runtime-gpu-cuda-12x)
180
+ - Исправление Issue https://github.com/Gourieff/comfyui-reactor-node/issues/173
181
+
182
+ - Отдельный Нод для восстаноления лиц (FR https://github.com/Gourieff/comfyui-reactor-node/issues/191), располагается внутри меню ReActor (нод RestoreFace)
183
+ - (Windows) Установка зависимостей теперь может быть выполнена в Python из PATH ОС
184
+ - Разные исправления и улучшения
185
+
186
+ - Face Restore Visibility и CodeFormer Weight (Fidelity) теперь доступны; не забудьте заново добавить Нод в ваших существующих воркфлоу
187
+
188
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.1-whatsnew-01.jpg?raw=true" alt="0.4.1-whatsnew-01" width="100%"/>
189
+
190
+ ### [0.4.0](https://github.com/Gourieff/comfyui-reactor-node/releases/tag/v0.4.0)
191
+
192
+ - Вход "input_image" теперь идёт первым, это даёт возможность корректного байпаса, а также это правильно с точки зрения расположения входов (главный вход - первый);
193
+ - Теперь можно сохранять модели лиц в качестве файлов "safetensors" (`ComfyUI\models\reactor\faces`) и загружать их в ReActor, реализуя разные сценарии использования, а также храня супер легкие модели лиц, которые вы чаще всего используете:
194
+
195
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.0-whatsnew-01.jpg?raw=true" alt="0.4.0-whatsnew-01" width="100%"/>
196
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.0-whatsnew-02.jpg?raw=true" alt="0.4.0-whatsnew-02" width="100%"/>
197
+
198
+ - Возможность сохранять модели лиц напрямую из изображения:
199
+
200
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/0.4.0-whatsnew-03.jpg?raw=true" alt="0.4.0-whatsnew-03" width="50%"/>
201
+
202
+ - Оба входа опциональны, присоедините один из них в соответствии с вашим воркфлоу; если присоеденены оба - вход `image` имеет приоритет.
203
+ - Различные исправления, делающие это расширение лучше.
204
+
205
+ Спасибо всем, кто находит ошибки, предлагает новые функции и поддерживает данный проект!
206
+
207
+ </details>
208
+
209
+ <a name="installation">
210
+
211
+ ## Установка
212
+
213
+ <details>
214
+ <summary>SD WebUI: <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/">AUTOMATIC1111</a> или <a href="https://github.com/vladmandic/automatic">SD.Next</a></summary>
215
+
216
+ 1. Закройте (остановите) SD-WebUI Сервер, если запущен
217
+ 2. (Для пользователей Windows):
218
+ - Установите [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/) (Например, версию Community - этот шаг нужен для правильной компиляции библиотеки Insightface)
219
+ - ИЛИ только [VS C++ Build Tools](https://visualstudio.microsoft.com/visual-cpp-build-tools/), выберите "Desktop Development with C++" в разделе "Workloads -> Desktop & Mobile"
220
+ - ИЛИ если же вы не хотите устанавливать что-либо из вышеуказанного - выполните [данные шаги (раздел. I)](#insightfacebuild)
221
+ 3. Перейдите в `extensions\sd-webui-comfyui\ComfyUI\custom_nodes`
222
+ 4. Откройте Консоль или Терминал и выполните `git clone https://github.com/Gourieff/ComfyUI-ReActor`
223
+ 5. Перейдите в корневую директорию SD WebUI, откройте Консоль или Терминал и выполните (для пользователей Windows)`.\venv\Scripts\activate` или (для пользователей Linux/MacOS)`venv/bin/activate`
224
+ 6. `python -m pip install -U pip`
225
+ 7. `cd extensions\sd-webui-comfyui\ComfyUI\custom_nodes\ComfyUI-ReActor`
226
+ 8. `python install.py`
227
+ 9. Пожалуйста, дождитесь полного завершения установки
228
+ 10. (Начиная с версии 0.3.0) Скачайте дополнительные модели восстановления лиц (по ссылке ниже) и сохраните их в папку `extensions\sd-webui-comfyui\ComfyUI\models\facerestore_models`:<br>
229
+ https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models/facerestore_models
230
+ 11. Запустите SD WebUI и проверьте консоль на сообщение, что ReActor Node работает:
231
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/uploads/console_status_running.jpg?raw=true" alt="console_status_running" width="759"/>
232
+
233
+ 12. Перейдите во вкладку ComfyUI и найдите там ReActor Node внутри меню `ReActor` или через поиск:
234
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/uploads/webui-demo.png?raw=true" alt="webui-demo" width="100%"/>
235
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/uploads/search-demo.png?raw=true" alt="webui-demo" width="1043"/>
236
+
237
+ </details>
238
+
239
+ <details>
240
+ <summary>Портативная версия <a href="https://github.com/comfyanonymous/ComfyUI">ComfyUI</a> для Windows</summary>
241
+
242
+ 1. Сделайте следующее:
243
+ - Установите [Visual Studio 2022](https://visualstudio.microsoft.com/downloads/) (Например, версию Community - этот шаг нужен для правильной компиляции библиотеки Insightface)
244
+ - ИЛИ только [VS C++ Build Tools](https://visualstudio.microsoft.com/visual-cpp-build-tools/), выберите "Desktop Development with C++" в разделе "Workloads -> Desktop & Mobile"
245
+ - ИЛИ если же вы не хотите устанавливать что-либо из вышеуказанного - выполните [данные шаги (раздел. I)](#insightfacebuild)
246
+ 2. Выберите из двух вариантов:
247
+ - (ComfyUI Manager) Откройте ComfyUI Manager, нажмите "Install Custom Nodes", введите "ReActor" в поле "Search" и далее нажмите "Install". После того, как ComfyUI завершит установку, перезагрузите сервер.
248
+ - (Вручную) Перейдите в `ComfyUI\custom_nodes`, откройте Консоль и выполните `git clone https://github.com/Gourieff/ComfyUI-ReActor`
249
+ 3. Перейдите `ComfyUI\custom_nodes\ComfyUI-ReActor` и запустите `install.bat`, дождитесь окончания установки
250
+ 4. Если модель "face_yolov8m.pt" у вас отсутствует - можете скачать её [отсюда](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/detection/bbox/face_yolov8m.pt) и положить в папку "ComfyUI\models\ultralytics\bbox"<br>
251
+ То же самое и с "Sams" моделями, скачайте одну или обе [отсюда](https://huggingface.co/datasets/Gourieff/ReActor/tree/main/models/sams) - и положите в папку "ComfyUI\models\sams"
252
+ 5. Запустите ComfyUI и найдите ReActor Node внутри меню `ReActor` или через поиск
253
+
254
+ </details>
255
+
256
+ <a name="usage">
257
+
258
+ ## Использование
259
+
260
+ Вы можете найти ноды ReActor внутри меню `ReActor` или через поиск (достаточно ввести "ReActor" в поисковой строке)
261
+
262
+ Список нодов:
263
+ - ••• Main Nodes •••
264
+ - ReActorFaceSwap (Основной нод)
265
+ - ReActorFaceSwapOpt (Основной нод с доп. входом Options)
266
+ - ReActorOptions (Опции для ReActorFaceSwapOpt)
267
+ - ReActorFaceBoost (Нод Face Booster)
268
+ - ReActorMaskHelper (Masking Helper)
269
+ - ••• Operations with Face Models •••
270
+ - ReActorSaveFaceModel (Save Face Model)
271
+ - ReActorLoadFaceModel (Load Face Model)
272
+ - ReActorBuildFaceModel (Build Blended Face Model)
273
+ - ReActorMakeFaceModelBatch (Make Face Model Batch)
274
+ - ••• Additional Nodes •••
275
+ - ReActorRestoreFace (Face Restoration)
276
+ - ReActorImageDublicator (Dublicate one Image to Images List)
277
+ - ImageRGBA2RGB (Convert RGBA to RGB)
278
+
279
+ Соедините все необходимые слоты (slots) и запустите очередь (queue).
280
+
281
+ ### Входы основного Нода
282
+
283
+ - `input_image` - это изображение, на котором надо поменять лицо или лица (целевое изображение, аналог "target image" в версии для SD WebUI);
284
+ - Поддерживаемые ноды: "Load Image", "Load Video" или любые другие ноды предоставляющие изображение в качестве выхода;
285
+ - `source_image` - это изображение с лицом или лицами для замены (изображение-источник, аналог "source image" в версии для SD WebUI);
286
+ - Поддерживаемые ноды: "Load Image" или любые другие ноды с выходом Image(s);
287
+ - `face_model` - это вход для выхода с нода "Load Face Model" или другого нода ReActor для загрузки модели лица (face model или face embedding), которое вы создали ранее через нод "Save Face Model";
288
+ - Поддерживаемые ноды: "Load Face Model", "Build Blended Face Model";
289
+
290
+ ### Выходы основного Нода
291
+
292
+ - `IMAGE` - выход с готовым изображением (результатом);
293
+ - Поддерживаемые ноды: любые ноды с изображением на входе;
294
+ - `FACE_MODEL` - выход, предоставляющий модель лица, построенную в ходе замены;
295
+ - Поддерживаемые ноды: "Save Face Model", "ReActor", "Make Face Model Batch";
296
+
297
+ ### Восстановление лиц
298
+
299
+ Начиная с версии 0.3.0 ReActor Node имеет встроенное восстановление лиц.<br>Скачайте нужные вам модели (см. инструкцию по [Установке](#installation)) и выберите одну из них, чтобы улучшить качество финального лица.
300
+
301
+ ### Индексы Лиц (Face Indexes)
302
+
303
+ По умолчанию ReActor определяет лица на изображении в порядке от "большого" к "малому".<br>Вы можете поменять эту опцию, используя нод ReActorFaceSwapOpt вместе с ReActorOptions.
304
+
305
+ Если вам нужно заменить определенное лицо, вы можете указать индекс для исходного (source, с лицом) и входного (input, где будет замена лица) изображений.
306
+
307
+ Индекс первого обнаруженного лица - 0.
308
+
309
+ Вы можете задать индексы в том порядке, который вам нужен.<br>
310
+ Например: 0,1,2 (для Source); 1,0,2 (для Input).<br>Это означает, что: второе лицо из Input (индекс = 1) будет заменено первым лицом из Source (индекс = 0) и так далее.
311
+
312
+ ### Определение Пола
313
+
314
+ Вы можете обозначить, какой пол нужно определять на изображении.<br>
315
+ ReActor заменит только то лицо, которое удовлетворяет заданному условию.
316
+
317
+ ### Модели Лиц
318
+ Начиная с версии 0.4.0, вы можете сохранять модели лиц как файлы "safetensors" (хранятся в папке `ComfyUI\models\reactor\faces`) и загружать их в ReActor, реализуя разные сценарии использования, а также храня супер легкие модели лиц, которые вы чаще всего используете.
319
+
320
+ Чтобы новые модели появились в списке моделей нода "Load Face Model" - обновите страницу вашего веб-приложения ComfyUI.<br>
321
+ (Рекомендую использовать ComfyUI Manager - иначе ваше воркфлоу может быть потеряно после перезагрузки страницы, если вы не сохранили его).
322
+
323
+ <a name="troubleshooting">
324
+
325
+ ## Устранение проблем
326
+
327
+ <a name="insightfacebuild">
328
+
329
+ ### **I. (Для пользователей Windows) Если вы до сих пор не можете установить пакет Insightface по каким-то причинам или же просто не желаете устанавливать Visual Studio или VS C++ Build Tools - сделайте следующее:**
330
+
331
+ 1. (ComfyUI Portable) Находясь в корневой директории, проверьте версию Python:<br>запустите CMD и выполните `python_embeded\python.exe -V`<br>Вы должны увидеть версию или 3.10, или 3.11, или 3.12
332
+ 2. Скачайте готовый пакет Insightface [для версии 3.10](https://github.com/Gourieff/sd-webui-reactor/raw/main/example/insightface-0.7.3-cp310-cp310-win_amd64.whl) или [для 3.11](https://github.com/Gourieff/Assets/raw/main/Insightface/insightface-0.7.3-cp311-cp311-win_amd64.whl) (если на предыдущем шаге вы увидели 3.11) или [для 3.12](https://github.com/Gourieff/Assets/raw/main/Insightface/insightface-0.7.3-cp312-cp312-win_amd64.whl) (если на предыдущем шаге вы увидели 3.12) и сохраните его в корневую директорию stable-diffusion-webui (A1111 или SD.Next) - туда, где лежит файл "webui-user.bat" -ИЛИ- в корневую директорию ComfyUI, если вы используете ComfyUI Portable
333
+ 3. Из корневой директории запустите:
334
+ - (SD WebUI) CMD и `.\venv\Scripts\activate`
335
+ - (ComfyUI Portable) CMD
336
+ 4. Обновите PIP:
337
+ - (SD WebUI) `python -m pip install -U pip`
338
+ - (ComfyUI Portable) `python_embeded\python.exe -m pip install -U pip`
339
+ 5. Затем установите Insightface:
340
+ - (SD WebUI) `pip install insightface-0.7.3-cp310-cp310-win_amd64.whl` (для 3.10) или `pip install insightface-0.7.3-cp311-cp311-win_amd64.whl` (для 3.11) или `pip install insightface-0.7.3-cp312-cp312-win_amd64.whl` (for 3.12)
341
+ - (ComfyUI Portable) `python_embeded\python.exe -m pip install insightface-0.7.3-cp310-cp310-win_amd64.whl` (для 3.10) или `python_embeded\python.exe -m pip install insightface-0.7.3-cp311-cp311-win_amd64.whl` (для 3.11) или `python_embeded\python.exe -m pip install insightface-0.7.3-cp312-cp312-win_amd64.whl` (for 3.12)
342
+ 6. Готово!
343
+
344
+ ### **II. "AttributeError: 'NoneType' object has no attribute 'get'"**
345
+
346
+ Эта ошибка появляется, если что-то не так с файлом модели `inswapper_128.onnx`
347
+
348
+ Скачайте вручную по ссылке [отсюда](https://github.com/facefusion/facefusion-assets/releases/download/models/inswapper_128.onnx)
349
+ и сохраните в директорию `ComfyUI\models\insightface`, заменив имеющийся файл
350
+
351
+ ### **III. "reactor.execute() got an unexpected keyword argument 'reference_image'"**
352
+
353
+ Это означает, что поменялось обозначение входных точек (input points) в связи с последним обновлением<br>
354
+ Удалите из вашего рабочего пространства имеющийся ReActor Node и добавьте его снова
355
+
356
+ ### **IV. ControlNet Aux Node IMPORT failed - при использовании совместно с нодом ReActor**
357
+
358
+ 1. Закройте или остановите ComfyUI сервер, если он запущен
359
+ 2. Перейдите в корневую папку ComfyUI, откройте консоль CMD и выполните следующее:
360
+ - `python_embeded\python.exe -m pip uninstall -y opencv-python opencv-contrib-python opencv-python-headless`
361
+ - `python_embeded\python.exe -m pip install opencv-python==4.7.0.72`
362
+ 3. Готово!
363
+
364
+ <img src="https://github.com/Gourieff/Assets/blob/main/comfyui-reactor-node/uploads/reactor-w-controlnet.png?raw=true" alt="reactor+controlnet" />
365
+
366
+ ### **V. "ModuleNotFoundError: No module named 'basicsr'" или "subprocess-exited-with-error" при установке пакета future-0.18.3**
367
+
368
+ - Скачайте https://github.com/Gourieff/Assets/raw/main/comfyui-reactor-node/future-0.18.3-py3-none-any.whl<br>
369
+ - Скопируйте файл в корневую папку ComfyUI и выполните в консоли:
370
+
371
+ python_embeded\python.exe -m pip install future-0.18.3-py3-none-any.whl
372
+
373
+ - Затем:
374
+
375
+ python_embeded\python.exe -m pip install basicsr
376
+
377
+ ### **VI. "fatal: fetch-pack: invalid index-pack output" при исполнении команды `git clone`"**
378
+
379
+ Попробуйте клонировать репозиторий с параметром `--depth=1` (только последний коммит):
380
+
381
+ git clone --depth=1 https://github.com/Gourieff/ComfyUI-ReActor
382
+
383
+ Затем вытяните оставшееся (если требуется):
384
+
385
+ git fetch --unshallow
386
+
387
+ <a name="updating">
388
+
389
+ ## Обновление
390
+
391
+ Положите .bat или .sh скрипт из [данного репозитория](https://github.com/Gourieff/sd-webui-extensions-updater) в папку `ComfyUI\custom_nodes` и запустите, когда желаете обновить ComfyUI и Ноды
392
+
393
+ <a name="disclaimer">
394
+
395
+ ## Ответственность
396
+
397
+ Это программное обеспечение призвано стать продуктивным вкладом в быстрорастущую медиаиндустрию на основе генеративных сетей и искусственного интеллекта. Данное ПО поможет художникам в решении таких задач, как анимация собственного персонажа или использование персонажа в качестве модели для одежды и т.д.
398
+
399
+ Разработчики этого программного обеспечения осведомлены о возможных неэтичных применениях и обязуются принять против этого превентивные меры. Мы продолжим развивать этот проект в позитивном направлении, придерживаясь закона и этики.
400
+
401
+ Подразумевается, что пользователи этого программного обеспечения будут использовать его ответственно, соблюдая локальное законодательство. Если используется лицо реального человека, пользователь обязан получить согласие заинтересованного лица и четко указать, что это дипфейк при размещении контента в Интернете. **Разработчики и Со-авторы данного программного обеспечения не несут ответственности за действия конечных пользователей.**
402
+
403
+ Используя данное расширение, вы соглашаетесь не создавать материалы, которые:
404
+ - нарушают какие-либо действующие законы тех или иных государств или международных организаций;
405
+ - причиняют какой-либо вред человеку или лицам;
406
+ - пропагандируют любую информацию (как общедоступную, так и личную) или изображения (как общедоступные, так и личные), которые могут быть направлены на причинение вреда;
407
+ - используются для распространения дезинформации;
408
+ - нацелены на уязвимые группы людей.
409
+
410
+ Данное программное обеспечение использует предварительно обученные модели `buffalo_l` и `inswapper_128.onnx`, представленные разработчиками [InsightFace](https://github.com/deepinsight/insightface/). Эти модели распространяются при следующих условиях:
411
+
412
+ [Перевод из текста лицензии insightface](https://github.com/deepinsight/insightface/tree/master/python-package): Предварительно обученные модели InsightFace доступны только для некоммерческих исследовательских целей. Сюда входят как модели с автоматической загрузкой, так и модели, загруженные вручную.
413
+
414
+ Пользователи данного программного обеспечения должны строго соблюдать данные условия использования. Разработчики и Со-авторы данного программного продукта не несут ответственности за неправильное использование предварительно обученных моделей InsightFace.
415
+
416
+ Обратите внимание: если вы собираетесь использовать это программное обеспечение в каких-либо коммерческих целях, вам необходимо будет обучить свои собственные модели или найти модели, которые можно использовать в коммерческих целях.
417
+
418
+ ### Хэш файлов моделей
419
+
420
+ #### Безопасные для использования модели имеют следующий хэш:
421
+
422
+ inswapper_128.onnx
423
+ ```
424
+ MD5:a3a155b90354160350efd66fed6b3d80
425
+ SHA256:e4a3f08c753cb72d04e10aa0f7dbe3deebbf39567d4ead6dce08e98aa49e16af
426
+ ```
427
+
428
+ 1k3d68.onnx
429
+
430
+ ```
431
+ MD5:6fb94fcdb0055e3638bf9158e6a108f4
432
+ SHA256:df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc
433
+ ```
434
+
435
+ 2d106det.onnx
436
+
437
+ ```
438
+ MD5:a3613ef9eb3662b4ef88eb90db1fcf26
439
+ SHA256:f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf
440
+ ```
441
+
442
+ det_10g.onnx
443
+
444
+ ```
445
+ MD5:4c10eef5c9e168357a16fdd580fa8371
446
+ SHA256:5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91
447
+ ```
448
+
449
+ genderage.onnx
450
+
451
+ ```
452
+ MD5:81c77ba87ab38163b0dec6b26f8e2af2
453
+ SHA256:4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb
454
+ ```
455
+
456
+ w600k_r50.onnx
457
+
458
+ ```
459
+ MD5:80248d427976241cbd1343889ed132b3
460
+ SHA256:4c06341c33c2ca1f86781dab0e829f88ad5b64be9fba56e56bc9ebdefc619e43
461
+ ```
462
+
463
+ **Пожалуйста, сравните хэш, если вы скачиваете данные модели из непроверенных источников**
464
+
465
+ <a name="credits">
466
+
467
+ ## Благодарности и авторы компонентов
468
+
469
+ <details>
470
+ <summary><a>Нажмите, чтобы посмотреть</a></summary>
471
+
472
+ <br>
473
+
474
+ |файл|источник|лицензия|
475
+ |----|--------|--------|
476
+ |[buffalo_l.zip](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/buffalo_l.zip) | [DeepInsight](https://github.com/deepinsight/insightface) | ![license](https://img.shields.io/badge/license-non_commercial-red) |
477
+ | [codeformer-v0.1.0.pth](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/facerestore_models/codeformer-v0.1.0.pth) | [sczhou](https://github.com/sczhou/CodeFormer) | ![license](https://img.shields.io/badge/license-non_commercial-red) |
478
+ | [GFPGANv1.3.pth](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/facerestore_models/GFPGANv1.3.pth) | [TencentARC](https://github.com/TencentARC/GFPGAN) | ![license](https://img.shields.io/badge/license-Apache_2.0-green.svg) |
479
+ | [GFPGANv1.4.pth](https://huggingface.co/datasets/Gourieff/ReActor/blob/main/models/facerestore_models/GFPGANv1.4.pth) | [TencentARC](https://github.com/TencentARC/GFPGAN) | ![license](https://img.shields.io/badge/license-Apache_2.0-green.svg) |
480
+ | [inswapper_128.onnx](https://github.com/facefusion/facefusion-assets/releases/download/models/inswapper_128.onnx) | [DeepInsight](https://github.com/deepinsight/insightface) | ![license](https://img.shields.io/badge/license-non_commercial-red) |
481
+ | [inswapper_128_fp16.onnx](https://github.com/facefusion/facefusion-assets/releases/download/models/inswapper_128_fp16.onnx) | [Hillobar](https://github.com/Hillobar/Rope) | ![license](https://img.shields.io/badge/license-non_commercial-red) |
482
+
483
+ [BasicSR](https://github.com/XPixelGroup/BasicSR) - [@XPixelGroup](https://github.com/XPixelGroup) <br>
484
+ [facexlib](https://github.com/xinntao/facexlib) - [@xinntao](https://github.com/xinntao) <br>
485
+
486
+ [@s0md3v](https://github.com/s0md3v), [@henryruhs](https://github.com/henryruhs) - оригинальное приложение Roop <br>
487
+ [@ssitu](https://github.com/ssitu) - первая версия расширения с поддержкой ComfyUI [ComfyUI_roop](https://github.com/ssitu/ComfyUI_roop)
488
+
489
+ </details>
490
+
491
+ <a name="note">
492
+
493
+ ### Обратите внимание!
494
+
495
+ **Если у вас возникли какие-либо ошибки при очередном использовании Нода ReActor - не торопитесь открывать Issue, для начала попробуйте удалить текущий Нод из вашего рабочего пространства и добавить его снова**
496
+
497
+ **ReActor Node периодически получает обновления, появляются новые функции, из-за чего имеющийся Нод может работать с ошибками или не работать вовсе**
custom_nodes/ComfyUI-ReActor/__init__.py ADDED
@@ -0,0 +1,39 @@
1
+ import sys
2
+ import os
3
+
4
+ repo_dir = os.path.dirname(os.path.realpath(__file__))
5
+ sys.path.insert(0, repo_dir)
6
+ original_modules = sys.modules.copy()
7
+
8
+ # Place aside existing modules if using a1111 web ui
9
+ modules_used = [
10
+ "modules",
11
+ "modules.images",
12
+ "modules.processing",
13
+ "modules.scripts_postprocessing",
14
+ "modules.scripts",
15
+ "modules.shared",
16
+ ]
17
+ original_webui_modules = {}
18
+ for module in modules_used:
19
+ if module in sys.modules:
20
+ original_webui_modules[module] = sys.modules.pop(module)
21
+
22
+ # Proceed with node setup
23
+ from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
24
+
25
+ __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
26
+
27
+ # Clean up imports
28
+ # Remove repo directory from path
29
+ sys.path.remove(repo_dir)
30
+ # Remove any new modules
31
+ modules_to_remove = []
32
+ for module in sys.modules:
33
+ if module not in original_modules and not module.startswith("google.protobuf") and not module.startswith("onnx") and not module.startswith("cv2"):
34
+ modules_to_remove.append(module)
35
+ for module in modules_to_remove:
36
+ del sys.modules[module]
37
+
38
+ # Restore original modules
39
+ sys.modules.update(original_webui_modules)
custom_nodes/ComfyUI-ReActor/install.bat ADDED
@@ -0,0 +1,37 @@
1
+ @echo off
2
+ setlocal enabledelayedexpansion
3
+
4
+ :: Try to use embedded python first
5
+ if exist ..\..\..\python_embeded\python.exe (
6
+ :: Use the embedded python
7
+ set PYTHON=..\..\..\python_embeded\python.exe
8
+ ) else (
9
+ :: Embedded python not found, check for python in the PATH
10
+ for /f "tokens=* USEBACKQ" %%F in (`python --version 2^>^&1`) do (
11
+ set PYTHON_VERSION=%%F
12
+ )
13
+ if errorlevel 1 (
14
+ echo I couldn't find an embedded version of Python, nor one in the Windows PATH. Please install manually.
15
+ pause
16
+ exit /b 1
17
+ ) else (
18
+ :: Use python from the PATH (if it's the right version and the user agrees)
19
+ echo I couldn't find an embedded version of Python, but I did find !PYTHON_VERSION! in your Windows PATH.
20
+ echo Would you like to proceed with the install using that version? (Y/N^)
21
+ set /p USE_PYTHON=
22
+ if /i "!USE_PYTHON!"=="Y" (
23
+ set PYTHON=python
24
+ ) else (
25
+ echo Okay. Please install manually.
26
+ pause
27
+ exit /b 1
28
+ )
29
+ )
30
+ )
31
+
32
+ :: Install the package
33
+ echo Installing...
34
+ %PYTHON% install.py
35
+ echo Done^!
36
+
37
+ @pause
custom_nodes/ComfyUI-ReActor/install.py ADDED
@@ -0,0 +1,104 @@
1
+ import warnings
2
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
3
+
4
+ import subprocess
5
+ import os, sys
6
+ try:
7
+ from pkg_resources import get_distribution as distributions
8
+ except:
9
+ from importlib_metadata import distributions
10
+ from tqdm import tqdm
11
+ import urllib.request
12
+ from packaging import version as pv
13
+ try:
14
+ from folder_paths import models_dir
15
+ except:
16
+ from pathlib import Path
17
+ models_dir = os.path.join(Path(__file__).parents[2], "models")
18
+
19
+ sys.path.append(os.path.dirname(os.path.realpath(__file__)))
20
+
21
+ req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt")
22
+
23
+ model_url = "https://huggingface.co/datasets/Gourieff/ReActor/resolve/main/models/inswapper_128.onnx"
24
+ model_name = os.path.basename(model_url)
25
+ models_dir_path = os.path.join(models_dir, "insightface")
26
+ model_path = os.path.join(models_dir_path, model_name)
27
+
28
+ def run_pip(*args):
29
+ subprocess.run([sys.executable, "-m", "pip", "install", "--no-warn-script-location", *args])
30
+
31
+ def is_installed (
32
+ package: str, version: str = None, strict: bool = True
33
+ ):
34
+ has_package = None
35
+ try:
36
+ has_package = distributions(package)
37
+ if has_package is not None:
38
+ if version is not None:
39
+ installed_version = has_package.version
40
+ if (installed_version != version and strict == True) or (pv.parse(installed_version) < pv.parse(version) and strict == False):
41
+ return False
42
+ else:
43
+ return True
44
+ else:
45
+ return True
46
+ else:
47
+ return False
48
+ except Exception as e:
49
+ print(f"Status: {e}")
50
+ return False
51
+
52
+ def download(url, path, name):
53
+ request = urllib.request.urlopen(url)
54
+ total = int(request.headers.get('Content-Length', 0))
55
+ with tqdm(total=total, desc=f'[ReActor] Downloading {name} to {path}', unit='B', unit_scale=True, unit_divisor=1024) as progress:
56
+ urllib.request.urlretrieve(url, path, reporthook=lambda count, block_size, total_size: progress.update(block_size))
57
+
58
+ if not os.path.exists(models_dir_path):
59
+ os.makedirs(models_dir_path)
60
+
61
+ if not os.path.exists(model_path):
62
+ download(model_url, model_path, model_name)
63
+
64
+ with open(req_file) as file:
65
+ try:
66
+ ort = "onnxruntime-gpu"
67
+ import torch
68
+ cuda_version = None
69
+ if torch.cuda.is_available():
70
+ cuda_version = torch.version.cuda
71
+ print(f"CUDA {cuda_version}")
72
+ elif torch.backends.mps.is_available() or hasattr(torch,'dml') or hasattr(torch,'privateuseone'):
73
+ ort = "onnxruntime"
74
+ if cuda_version is not None and float(cuda_version)>=12 and torch.torch_version.__version__ <= "2.2.0": # CU12.x and torch<=2.2.0
75
+ print(f"Torch: {torch.torch_version.__version__}")
76
+ if not is_installed(ort,"1.17.0",False):
77
+ run_pip(ort,"--extra-index-url", "https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/")
78
+ elif cuda_version is not None and float(cuda_version)>=12 and torch.torch_version.__version__ >= "2.4.0" : # CU12.x and latest torch
79
+ print(f"Torch: {torch.torch_version.__version__}")
80
+ if not is_installed(ort,"1.20.1",False): # latest ort-gpu
81
+ run_pip(ort,"-U")
82
+ elif not is_installed(ort,"1.16.1",False):
83
+ run_pip(ort, "-U")
84
+ except Exception as e:
85
+ print(e)
86
+ print(f"Warning: Failed to install {ort}, ReActor will not work.")
87
+ raise e
88
+ strict = True
89
+ for package in file:
90
+ package_version = None
91
+ try:
92
+ package = package.strip()
93
+ if "==" in package:
94
+ package_version = package.split('==')[1]
95
+ elif ">=" in package:
96
+ package_version = package.split('>=')[1]
97
+ strict = False
98
+ if not is_installed(package,package_version,strict):
99
+ run_pip(package)
100
+ except Exception as e:
101
+ print(e)
102
+ print(f"Warning: Failed to install {package}, ReActor will not work.")
103
+ raise e
104
+ print("Ok")
custom_nodes/ComfyUI-ReActor/modules/__init__.py ADDED
File without changes
custom_nodes/ComfyUI-ReActor/modules/images.py ADDED
File without changes
custom_nodes/ComfyUI-ReActor/modules/processing.py ADDED
@@ -0,0 +1,13 @@
1
+ class StableDiffusionProcessing:
2
+
3
+ def __init__(self, init_imgs):
4
+ self.init_images = init_imgs
5
+ self.width = init_imgs[0].width
6
+ self.height = init_imgs[0].height
7
+ self.extra_generation_params = {}
8
+
9
+
10
+ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
11
+
12
+ def __init__(self, init_img):
13
+ super().__init__(init_img)
custom_nodes/ComfyUI-ReActor/modules/scripts.py ADDED
@@ -0,0 +1,13 @@
1
+ import os
2
+
3
+
4
+ class Script:
5
+ pass
6
+
7
+
8
+ def basedir():
9
+ return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
10
+
11
+
12
+ class PostprocessImageArgs:
13
+ pass
custom_nodes/ComfyUI-ReActor/modules/scripts_postprocessing.py ADDED
File without changes
custom_nodes/ComfyUI-ReActor/modules/shared.py ADDED
@@ -0,0 +1,19 @@
1
+ class Options:
2
+ img2img_background_color = "#ffffff" # Set to white for now
3
+
4
+
5
+ class State:
6
+ interrupted = False
7
+
8
+ def begin(self):
9
+ pass
10
+
11
+ def end(self):
12
+ pass
13
+
14
+
15
+ opts = Options()
16
+ state = State()
17
+ cmd_opts = None
18
+ sd_upscalers = []
19
+ face_restorers = []
custom_nodes/ComfyUI-ReActor/nodes.py ADDED
@@ -0,0 +1,1364 @@
1
+ import os, glob, sys
2
+ import logging
3
+
4
+ import torch
5
+ import torch.nn.functional as torchfn
6
+ from torchvision.transforms.functional import normalize
7
+ from torchvision.ops import masks_to_boxes
8
+
9
+ import numpy as np
10
+ import cv2
11
+ import math
12
+ from typing import List
13
+ from PIL import Image
14
+ from scipy import stats
15
+ from insightface.app.common import Face
16
+ from segment_anything import sam_model_registry
17
+
18
+ from modules.processing import StableDiffusionProcessingImg2Img
19
+ from modules.shared import state
20
+ # from comfy_extras.chainner_models import model_loading
21
+ import comfy.model_management as model_management
22
+ import comfy.utils
23
+ import folder_paths
24
+
25
+ import scripts.reactor_version
26
+ from r_chainner import model_loading
27
+ from scripts.reactor_faceswap import (
28
+ FaceSwapScript,
29
+ get_models,
30
+ get_current_faces_model,
31
+ analyze_faces,
32
+ half_det_size,
33
+ providers
34
+ )
35
+ from scripts.reactor_swapper import (
36
+ unload_all_models,
37
+ )
38
+ from scripts.reactor_logger import logger
39
+ from reactor_utils import (
40
+ batch_tensor_to_pil,
41
+ batched_pil_to_tensor,
42
+ tensor_to_pil,
43
+ img2tensor,
44
+ tensor2img,
45
+ save_face_model,
46
+ load_face_model,
47
+ download,
48
+ set_ort_session,
49
+ prepare_cropped_face,
50
+ normalize_cropped_face,
51
+ add_folder_path_and_extensions,
52
+ rgba2rgb_tensor
53
+ )
54
+ from reactor_patcher import apply_patch
55
+ from r_facelib.utils.face_restoration_helper import FaceRestoreHelper
56
+ from r_basicsr.utils.registry import ARCH_REGISTRY
57
+ import scripts.r_archs.codeformer_arch
58
+ import scripts.r_masking.subcore as subcore
59
+ import scripts.r_masking.core as core
60
+ import scripts.r_masking.segs as masking_segs
61
+
62
+ import scripts.reactor_sfw as sfw
63
+
64
+
65
+ models_dir = folder_paths.models_dir
66
+ REACTOR_MODELS_PATH = os.path.join(models_dir, "reactor")
67
+ FACE_MODELS_PATH = os.path.join(REACTOR_MODELS_PATH, "faces")
68
+ NSFWDET_MODEL_PATH = os.path.join(models_dir, "nsfw_detector","vit-base-nsfw-detector")
69
+
70
+ if not os.path.exists(REACTOR_MODELS_PATH):
71
+ os.makedirs(REACTOR_MODELS_PATH)
72
+ if not os.path.exists(FACE_MODELS_PATH):
73
+ os.makedirs(FACE_MODELS_PATH)
74
+
75
+ dir_facerestore_models = os.path.join(models_dir, "facerestore_models")
76
+ os.makedirs(dir_facerestore_models, exist_ok=True)
77
+ folder_paths.folder_names_and_paths["facerestore_models"] = ([dir_facerestore_models], folder_paths.supported_pt_extensions)
78
+
79
+ BLENDED_FACE_MODEL = None
80
+ FACE_SIZE: int = 512
81
+ FACE_HELPER = None
82
+
83
+ if "ultralytics" not in folder_paths.folder_names_and_paths:
84
+ add_folder_path_and_extensions("ultralytics_bbox", [os.path.join(models_dir, "ultralytics", "bbox")], folder_paths.supported_pt_extensions)
85
+ add_folder_path_and_extensions("ultralytics_segm", [os.path.join(models_dir, "ultralytics", "segm")], folder_paths.supported_pt_extensions)
86
+ add_folder_path_and_extensions("ultralytics", [os.path.join(models_dir, "ultralytics")], folder_paths.supported_pt_extensions)
87
+ if "sams" not in folder_paths.folder_names_and_paths:
88
+ add_folder_path_and_extensions("sams", [os.path.join(models_dir, "sams")], folder_paths.supported_pt_extensions)
89
+
90
+ def get_facemodels():
91
+ models_path = os.path.join(FACE_MODELS_PATH, "*")
92
+ models = glob.glob(models_path)
93
+ models = [x for x in models if x.endswith(".safetensors")]
94
+ return models
95
+
96
+ def get_restorers():
97
+ models_path = os.path.join(models_dir, "facerestore_models/*")
98
+ models = glob.glob(models_path)
99
+ models = [x for x in models if (x.endswith(".pth") or x.endswith(".onnx"))]
100
+ if len(models) == 0:
101
+ fr_urls = [
102
+ "https://huggingface.co/datasets/Gourieff/ReActor/resolve/main/models/facerestore_models/GFPGANv1.3.pth",
103
+ "https://huggingface.co/datasets/Gourieff/ReActor/resolve/main/models/facerestore_models/GFPGANv1.4.pth",
104
+ "https://huggingface.co/datasets/Gourieff/ReActor/resolve/main/models/facerestore_models/codeformer-v0.1.0.pth",
105
+ "https://huggingface.co/datasets/Gourieff/ReActor/resolve/main/models/facerestore_models/GPEN-BFR-512.onnx",
106
+ ]
107
+ for model_url in fr_urls:
108
+ model_name = os.path.basename(model_url)
109
+ model_path = os.path.join(dir_facerestore_models, model_name)
110
+ download(model_url, model_path, model_name)
111
+ models = glob.glob(models_path)
112
+ models = [x for x in models if (x.endswith(".pth") or x.endswith(".onnx"))]
113
+ return models
114
+
115
+ def get_model_names(get_models):
116
+ models = get_models()
117
+ names = []
118
+ for x in models:
119
+ names.append(os.path.basename(x))
120
+ names.sort(key=str.lower)
121
+ names.insert(0, "none")
122
+ return names
123
+
124
+ def model_names():
125
+ models = get_models()
126
+ return {os.path.basename(x): x for x in models}
127
+
128
+
129
+ class reactor:
130
+ @classmethod
131
+ def INPUT_TYPES(s):
132
+ return {
133
+ "required": {
134
+ "enabled": ("BOOLEAN", {"default": True, "label_off": "OFF", "label_on": "ON"}),
135
+ "input_image": ("IMAGE",),
136
+ "swap_model": (list(model_names().keys()),),
137
+ "facedetection": (["retinaface_resnet50", "retinaface_mobile0.25", "YOLOv5l", "YOLOv5n"],),
138
+ "face_restore_model": (get_model_names(get_restorers),),
139
+ "face_restore_visibility": ("FLOAT", {"default": 1, "min": 0.1, "max": 1, "step": 0.05}),
140
+ "codeformer_weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1, "step": 0.05}),
141
+ "detect_gender_input": (["no","female","male"], {"default": "no"}),
142
+ "detect_gender_source": (["no","female","male"], {"default": "no"}),
143
+ "input_faces_index": ("STRING", {"default": "0"}),
144
+ "source_faces_index": ("STRING", {"default": "0"}),
145
+ "console_log_level": ([0, 1, 2], {"default": 1}),
146
+ },
147
+ "optional": {
148
+ "source_image": ("IMAGE",),
149
+ "face_model": ("FACE_MODEL",),
150
+ "face_boost": ("FACE_BOOST",),
151
+ },
152
+ "hidden": {"faces_order": "FACES_ORDER"},
153
+ }
154
+
155
+ RETURN_TYPES = ("IMAGE","FACE_MODEL")
156
+ FUNCTION = "execute"
157
+ CATEGORY = "🌌 ReActor"
158
+
159
+ def __init__(self):
160
+ # self.face_helper = None
161
+ self.faces_order = ["large-small", "large-small"]
162
+ # self.face_size = FACE_SIZE
163
+ self.face_boost_enabled = False
164
+ self.restore = True
165
+ self.boost_model = None
166
+ self.interpolation = "Bicubic"
167
+ self.boost_model_visibility = 1
168
+ self.boost_cf_weight = 0.5
169
+
170
+ def restore_face(
171
+ self,
172
+ input_image,
173
+ face_restore_model,
174
+ face_restore_visibility,
175
+ codeformer_weight,
176
+ facedetection,
177
+ ):
178
+
179
+ result = input_image
180
+
181
+ if face_restore_model != "none" and not model_management.processing_interrupted():
182
+
183
+ global FACE_SIZE, FACE_HELPER
184
+
185
+ self.face_helper = FACE_HELPER
186
+
187
+ faceSize = 512
188
+ if "1024" in face_restore_model.lower():
189
+ faceSize = 1024
190
+ elif "2048" in face_restore_model.lower():
191
+ faceSize = 2048
192
+
193
+ logger.status(f"Restoring with {face_restore_model} | Face Size is set to {faceSize}")
194
+
195
+ model_path = folder_paths.get_full_path("facerestore_models", face_restore_model)
196
+
197
+ device = model_management.get_torch_device()
198
+
199
+ if "codeformer" in face_restore_model.lower():
200
+
201
+ codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
202
+ dim_embd=512,
203
+ codebook_size=1024,
204
+ n_head=8,
205
+ n_layers=9,
206
+ connect_list=["32", "64", "128", "256"],
207
+ ).to(device)
208
+ checkpoint = torch.load(model_path)["params_ema"]
209
+ codeformer_net.load_state_dict(checkpoint)
210
+ facerestore_model = codeformer_net.eval()
211
+
212
+ elif ".onnx" in face_restore_model:
213
+
214
+ ort_session = set_ort_session(model_path, providers=providers)
215
+ ort_session_inputs = {}
216
+ facerestore_model = ort_session
217
+
218
+ else:
219
+
220
+ sd = comfy.utils.load_torch_file(model_path, safe_load=True)
221
+ facerestore_model = model_loading.load_state_dict(sd).eval()
222
+ facerestore_model.to(device)
223
+
224
+ if faceSize != FACE_SIZE or self.face_helper is None:
225
+ self.face_helper = FaceRestoreHelper(1, face_size=faceSize, crop_ratio=(1, 1), det_model=facedetection, save_ext='png', use_parse=True, device=device)
226
+ FACE_SIZE = faceSize
227
+ FACE_HELPER = self.face_helper
228
+
229
+ image_np = 255. * result.numpy()
230
+
231
+ total_images = image_np.shape[0]
232
+
233
+ out_images = []
234
+
235
+ for i in range(total_images):
236
+
237
+ if total_images > 1:
238
+ logger.status(f"Restoring {i+1}")
239
+
240
+ cur_image_np = image_np[i,:, :, ::-1]
241
+
242
+ original_resolution = cur_image_np.shape[0:2]
243
+
244
+ if facerestore_model is None or self.face_helper is None:
245
+ return result
246
+
247
+ self.face_helper.clean_all()
248
+ self.face_helper.read_image(cur_image_np)
249
+ self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
250
+ self.face_helper.align_warp_face()
251
+
252
+ restored_face = None
253
+
254
+ for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
255
+
256
+ # if ".pth" in face_restore_model:
257
+ cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
258
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
259
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
260
+
261
+ try:
262
+
263
+ with torch.no_grad():
264
+
265
+ if ".onnx" in face_restore_model: # ONNX models
266
+
267
+ for ort_session_input in ort_session.get_inputs():
268
+ if ort_session_input.name == "input":
269
+ cropped_face_prep = prepare_cropped_face(cropped_face)
270
+ ort_session_inputs[ort_session_input.name] = cropped_face_prep
271
+ if ort_session_input.name == "weight":
272
+ weight = np.array([ 1 ], dtype = np.double)
273
+ ort_session_inputs[ort_session_input.name] = weight
274
+
275
+ output = ort_session.run(None, ort_session_inputs)[0][0]
276
+ restored_face = normalize_cropped_face(output)
277
+
278
+ else: # PTH models
279
+
280
+ output = facerestore_model(cropped_face_t, w=codeformer_weight)[0] if "codeformer" in face_restore_model.lower() else facerestore_model(cropped_face_t)[0]
281
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
282
+
283
+ del output
284
+ torch.cuda.empty_cache()
285
+
286
+ except Exception as error:
287
+
288
+ print(f"\tFailed inference: {error}", file=sys.stderr)
289
+ restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
290
+
291
+ if face_restore_visibility < 1:
292
+ restored_face = cropped_face * (1 - face_restore_visibility) + restored_face * face_restore_visibility
293
+
294
+ restored_face = restored_face.astype("uint8")
295
+ self.face_helper.add_restored_face(restored_face)
296
+
297
+ self.face_helper.get_inverse_affine(None)
298
+
299
+ restored_img = self.face_helper.paste_faces_to_input_image()
300
+ restored_img = restored_img[:, :, ::-1]
301
+
302
+ if original_resolution != restored_img.shape[0:2]:
303
+ restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_AREA)
304
+
305
+ self.face_helper.clean_all()
306
+
307
+ # out_images[i] = restored_img
308
+ out_images.append(restored_img)
309
+
310
+ if state.interrupted or model_management.processing_interrupted():
311
+ logger.status("Interrupted by User")
312
+ return input_image
313
+
314
+ restored_img_np = np.array(out_images).astype(np.float32) / 255.0
315
+ restored_img_tensor = torch.from_numpy(restored_img_np)
316
+
317
+ result = restored_img_tensor
318
+
319
+ return result
320
+
321
+ def execute(self, enabled, input_image, swap_model, detect_gender_source, detect_gender_input, source_faces_index, input_faces_index, console_log_level, face_restore_model,face_restore_visibility, codeformer_weight, facedetection, source_image=None, face_model=None, faces_order=None, face_boost=None):
322
+
323
+ if face_boost is not None:
324
+ self.face_boost_enabled = face_boost["enabled"]
325
+ self.boost_model = face_boost["boost_model"]
326
+ self.interpolation = face_boost["interpolation"]
327
+ self.boost_model_visibility = face_boost["visibility"]
328
+ self.boost_cf_weight = face_boost["codeformer_weight"]
329
+ self.restore = face_boost["restore_with_main_after"]
330
+ else:
331
+ self.face_boost_enabled = False
332
+
333
+ if faces_order is None:
334
+ faces_order = self.faces_order
335
+
336
+ apply_patch(console_log_level)
337
+
338
+ if not enabled:
339
+ return (input_image,face_model)
340
+ elif source_image is None and face_model is None:
341
+ logger.error("Please provide 'source_image' or `face_model`")
342
+ return (input_image,face_model)
343
+
344
+ if face_model == "none":
345
+ face_model = None
346
+
347
+ script = FaceSwapScript()
348
+ pil_images = batch_tensor_to_pil(input_image)
349
+
350
+ # NSFW checker
351
+ logger.status("Checking for any unsafe content")
352
+ pil_images_sfw = []
353
+ tmp_img = "reactor_tmp.png"
354
+ for img in pil_images:
355
+ if state.interrupted or model_management.processing_interrupted():
356
+ logger.status("Interrupted by User")
357
+ break
358
+ img.save(tmp_img)
359
+ if not sfw.nsfw_image(tmp_img, NSFWDET_MODEL_PATH):
360
+ pil_images_sfw.append(img)
361
+ if os.path.exists(tmp_img):
362
+ os.remove(tmp_img)
363
+ pil_images = pil_images_sfw
364
+ # # #
365
+
366
+ if len(pil_images) > 0:
367
+
368
+ if source_image is not None:
369
+ source = tensor_to_pil(source_image)
370
+ else:
371
+ source = None
372
+ p = StableDiffusionProcessingImg2Img(pil_images)
373
+ script.process(
374
+ p=p,
375
+ img=source,
376
+ enable=True,
377
+ source_faces_index=source_faces_index,
378
+ faces_index=input_faces_index,
379
+ model=swap_model,
380
+ swap_in_source=True,
381
+ swap_in_generated=True,
382
+ gender_source=detect_gender_source,
383
+ gender_target=detect_gender_input,
384
+ face_model=face_model,
385
+ faces_order=faces_order,
386
+ # face boost:
387
+ face_boost_enabled=self.face_boost_enabled,
388
+ face_restore_model=self.boost_model,
389
+ face_restore_visibility=self.boost_model_visibility,
390
+ codeformer_weight=self.boost_cf_weight,
391
+ interpolation=self.interpolation,
392
+ )
393
+ result = batched_pil_to_tensor(p.init_images)
394
+
395
+ if face_model is None:
396
+ current_face_model = get_current_faces_model()
397
+ face_model_to_provide = current_face_model[0] if (current_face_model is not None and len(current_face_model) > 0) else face_model
398
+ else:
399
+ face_model_to_provide = face_model
400
+
401
+ if self.restore or not self.face_boost_enabled:
402
+ result = reactor.restore_face(self,result,face_restore_model,face_restore_visibility,codeformer_weight,facedetection)
403
+
404
+ else:
405
+ image_black = Image.new("RGB", (512, 512))
406
+ result = batched_pil_to_tensor([image_black])
407
+ face_model_to_provide = None
408
+
409
+ return (result,face_model_to_provide)
410
+
411
+
412
+ class ReActorPlusOpt:
413
+ @classmethod
414
+ def INPUT_TYPES(s):
415
+ return {
416
+ "required": {
417
+ "enabled": ("BOOLEAN", {"default": True, "label_off": "OFF", "label_on": "ON"}),
418
+ "input_image": ("IMAGE",),
419
+ "swap_model": (list(model_names().keys()),),
420
+ "facedetection": (["retinaface_resnet50", "retinaface_mobile0.25", "YOLOv5l", "YOLOv5n"],),
421
+ "face_restore_model": (get_model_names(get_restorers),),
422
+ "face_restore_visibility": ("FLOAT", {"default": 1, "min": 0.1, "max": 1, "step": 0.05}),
423
+ "codeformer_weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1, "step": 0.05}),
424
+ },
425
+ "optional": {
426
+ "source_image": ("IMAGE",),
427
+ "face_model": ("FACE_MODEL",),
428
+ "options": ("OPTIONS",),
429
+ "face_boost": ("FACE_BOOST",),
430
+ }
431
+ }
432
+
433
+ RETURN_TYPES = ("IMAGE","FACE_MODEL")
434
+ FUNCTION = "execute"
435
+ CATEGORY = "🌌 ReActor"
436
+
437
+ def __init__(self):
438
+ # self.face_helper = None
439
+ self.faces_order = ["large-small", "large-small"]
440
+ self.detect_gender_input = "no"
441
+ self.detect_gender_source = "no"
442
+ self.input_faces_index = "0"
443
+ self.source_faces_index = "0"
444
+ self.console_log_level = 1
445
+ # self.face_size = 512
446
+ self.face_boost_enabled = False
447
+ self.restore = True
448
+ self.boost_model = None
449
+ self.interpolation = "Bicubic"
450
+ self.boost_model_visibility = 1
451
+ self.boost_cf_weight = 0.5
452
+
453
+ def execute(self, enabled, input_image, swap_model, facedetection, face_restore_model, face_restore_visibility, codeformer_weight, source_image=None, face_model=None, options=None, face_boost=None):
454
+
455
+ if options is not None:
456
+ self.faces_order = [options["input_faces_order"], options["source_faces_order"]]
457
+ self.console_log_level = options["console_log_level"]
458
+ self.detect_gender_input = options["detect_gender_input"]
459
+ self.detect_gender_source = options["detect_gender_source"]
460
+ self.input_faces_index = options["input_faces_index"]
461
+ self.source_faces_index = options["source_faces_index"]
462
+
463
+ if face_boost is not None:
464
+ self.face_boost_enabled = face_boost["enabled"]
465
+ self.restore = face_boost["restore_with_main_after"]
466
+ else:
467
+ self.face_boost_enabled = False
468
+
469
+ result = reactor.execute(
470
+ self,enabled,input_image,swap_model,self.detect_gender_source,self.detect_gender_input,self.source_faces_index,self.input_faces_index,self.console_log_level,face_restore_model,face_restore_visibility,codeformer_weight,facedetection,source_image,face_model,self.faces_order, face_boost=face_boost
471
+ )
472
+
473
+ return result
474
+
475
+
476
+ class LoadFaceModel:
477
+ @classmethod
478
+ def INPUT_TYPES(s):
479
+ return {
480
+ "required": {
481
+ "face_model": (get_model_names(get_facemodels),),
482
+ }
483
+ }
484
+
485
+ RETURN_TYPES = ("FACE_MODEL",)
486
+ FUNCTION = "load_model"
487
+ CATEGORY = "🌌 ReActor"
488
+
489
+ def load_model(self, face_model):
490
+ self.face_model = face_model
491
+ self.face_models_path = FACE_MODELS_PATH
492
+ if self.face_model != "none":
493
+ face_model_path = os.path.join(self.face_models_path, self.face_model)
494
+ out = load_face_model(face_model_path)
495
+ else:
496
+ out = None
497
+ return (out, )
498
+
499
+
500
+ class ReActorWeight:
501
+ @classmethod
502
+ def INPUT_TYPES(s):
503
+ return {
504
+ "required": {
505
+ "input_image": ("IMAGE",),
506
+ "faceswap_weight": (["0%", "12.5%", "25%", "37.5%", "50%", "62.5%", "75%", "87.5%", "100%"], {"default": "50%"}),
507
+ },
508
+ "optional": {
509
+ "source_image": ("IMAGE",),
510
+ "face_model": ("FACE_MODEL",),
511
+ }
512
+ }
513
+
514
+ RETURN_TYPES = ("IMAGE","FACE_MODEL")
515
+ RETURN_NAMES = ("INPUT_IMAGE","FACE_MODEL")
516
+ FUNCTION = "set_weight"
517
+
518
+ OUTPUT_NODE = True
519
+
520
+ CATEGORY = "🌌 ReActor"
521
+
522
+ def set_weight(self, input_image, faceswap_weight, face_model=None, source_image=None):
523
+
524
+ if input_image is None:
525
+ logger.error("Please provide `input_image`")
526
+ return (input_image,None)
527
+
528
+ if source_image is None and face_model is None:
529
+ logger.error("Please provide `source_image` or `face_model`")
530
+ return (input_image,None)
531
+
532
+ weight = float(faceswap_weight.split("%")[0])
533
+
534
+ images = []
535
+ faces = [] if face_model is None else [face_model]
536
+ embeddings = [] if face_model is None else [face_model.embedding]
537
+
538
+ if weight == 0:
539
+ images = [input_image]
540
+ faces = []
541
+ embeddings = []
542
+ elif weight == 100:
543
+ if face_model is None:
544
+ images = [source_image]
545
+ else:
546
+ if weight > 50:
547
+ images = [input_image]
548
+ count = round(100/(100-weight))
549
+ else:
550
+ if face_model is None:
551
+ images = [source_image]
552
+ count = round(100/(weight))
553
+ for i in range(count-1):
554
+ if weight > 50:
555
+ if face_model is None:
556
+ images.append(source_image)
557
+ else:
558
+ faces.append(face_model)
559
+ embeddings.append(face_model.embedding)
560
+ else:
561
+ images.append(input_image)
562
+
563
+ images_list: List[Image.Image] = []
564
+
565
+ apply_patch(1)
566
+
567
+ if len(images) > 0:
568
+
569
+ for image in images:
570
+ img = tensor_to_pil(image)
571
+ images_list.append(img)
572
+
573
+ for image in images_list:
574
+ face = BuildFaceModel.build_face_model(self,image)
575
+ if isinstance(face, str):
576
+ continue
577
+ faces.append(face)
578
+ embeddings.append(face.embedding)
579
+
580
+ if len(faces) > 0:
581
+ blended_embedding = np.mean(embeddings, axis=0)
582
+ blended_face = Face(
583
+ bbox=faces[0].bbox,
584
+ kps=faces[0].kps,
585
+ det_score=faces[0].det_score,
586
+ landmark_3d_68=faces[0].landmark_3d_68,
587
+ pose=faces[0].pose,
588
+ landmark_2d_106=faces[0].landmark_2d_106,
589
+ embedding=blended_embedding,
590
+ gender=faces[0].gender,
591
+ age=faces[0].age
592
+ )
593
+ if blended_face is None:
594
+ no_face_msg = "Something went wrong, please try another set of images"
595
+ logger.error(no_face_msg)
596
+
597
+ return (input_image,blended_face)
598
+
599
+
600
+ class BuildFaceModel:
601
+ def __init__(self):
602
+ self.output_dir = FACE_MODELS_PATH
603
+
604
+ @classmethod
605
+ def INPUT_TYPES(s):
606
+ return {
607
+ "required": {
608
+ "save_mode": ("BOOLEAN", {"default": True, "label_off": "OFF", "label_on": "ON"}),
609
+ "send_only": ("BOOLEAN", {"default": False, "label_off": "NO", "label_on": "YES"}),
610
+ "face_model_name": ("STRING", {"default": "default"}),
611
+ "compute_method": (["Mean", "Median", "Mode"], {"default": "Mean"}),
612
+ },
613
+ "optional": {
614
+ "images": ("IMAGE",),
615
+ "face_models": ("FACE_MODEL",),
616
+ }
617
+ }
618
+
619
+ RETURN_TYPES = ("FACE_MODEL",)
620
+ FUNCTION = "blend_faces"
621
+
622
+ OUTPUT_NODE = True
623
+
624
+ CATEGORY = "🌌 ReActor"
625
+
626
+ def build_face_model(self, image: Image.Image, det_size=(640, 640)):
627
+ logging.StreamHandler.terminator = "\n"
628
+ if image is None:
629
+ error_msg = "Please load an Image"
630
+ logger.error(error_msg)
631
+ return error_msg
632
+ image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
633
+ face_model = analyze_faces(image, det_size)
634
+
635
+ if len(face_model) == 0:
636
+ print("")
637
+ det_size_half = half_det_size(det_size)
638
+ face_model = analyze_faces(image, det_size_half)
639
+ if face_model is not None and len(face_model) > 0:
640
+ print("...........................................................", end=" ")
641
+
642
+ if face_model is not None and len(face_model) > 0:
643
+ return face_model[0]
644
+ else:
645
+ no_face_msg = "No face found, please try another image"
646
+ # logger.error(no_face_msg)
647
+ return no_face_msg
648
+
649
+ def blend_faces(self, save_mode, send_only, face_model_name, compute_method, images=None, face_models=None):
650
+ global BLENDED_FACE_MODEL
651
+ blended_face: Face = BLENDED_FACE_MODEL
652
+
653
+ if send_only and blended_face is None:
654
+ send_only = False
655
+
656
+ if (images is not None or face_models is not None) and not send_only:
657
+
658
+ faces = []
659
+ embeddings = []
660
+
661
+ apply_patch(1)
662
+
663
+ if images is not None:
664
+ images_list: List[Image.Image] = batch_tensor_to_pil(images)
665
+
666
+ n = len(images_list)
667
+
668
+ for i,image in enumerate(images_list):
669
+ logging.StreamHandler.terminator = " "
670
+ logger.status(f"Building Face Model {i+1} of {n}...")
671
+ face = self.build_face_model(image)
672
+ if isinstance(face, str):
673
+ logger.error(f"No faces found in image {i+1}, skipping")
674
+ continue
675
+ else:
676
+ print(f"{int(((i+1)/n)*100)}%")
677
+ faces.append(face)
678
+ embeddings.append(face.embedding)
679
+
680
+ elif face_models is not None:
681
+
682
+ n = len(face_models)
683
+
684
+ for i,face_model in enumerate(face_models):
685
+ logging.StreamHandler.terminator = " "
686
+ logger.status(f"Extracting Face Model {i+1} of {n}...")
687
+ face = face_model
688
+ if isinstance(face, str):
689
+ logger.error(f"No faces found for face_model {i+1}, skipping")
690
+ continue
691
+ else:
692
+ print(f"{int(((i+1)/n)*100)}%")
693
+ faces.append(face)
694
+ embeddings.append(face.embedding)
695
+
696
+ logging.StreamHandler.terminator = "\n"
697
+ if len(faces) > 0:
698
+ # compute_method_name = "Mean" if compute_method == 0 else "Median" if compute_method == 1 else "Mode"
699
+ logger.status(f"Blending with Compute Method '{compute_method}'...")
700
+ blended_embedding = np.mean(embeddings, axis=0) if compute_method == "Mean" else np.median(embeddings, axis=0) if compute_method == "Median" else stats.mode(embeddings, axis=0)[0].astype(np.float32)
701
+ blended_face = Face(
702
+ bbox=faces[0].bbox,
703
+ kps=faces[0].kps,
704
+ det_score=faces[0].det_score,
705
+ landmark_3d_68=faces[0].landmark_3d_68,
706
+ pose=faces[0].pose,
707
+ landmark_2d_106=faces[0].landmark_2d_106,
708
+ embedding=blended_embedding,
709
+ gender=faces[0].gender,
710
+ age=faces[0].age
711
+ )
712
+ if blended_face is not None:
713
+ BLENDED_FACE_MODEL = blended_face
714
+ if save_mode:
715
+ face_model_path = os.path.join(FACE_MODELS_PATH, face_model_name + ".safetensors")
716
+ save_face_model(blended_face,face_model_path)
717
+ # done_msg = f"Face model has been saved to '{face_model_path}'"
718
+ # logger.status(done_msg)
719
+ logger.status("--Done!--")
720
+ # return (blended_face,)
721
+ else:
722
+ no_face_msg = "Something went wrong, please try another set of images"
723
+ logger.error(no_face_msg)
724
+ # return (blended_face,)
725
+ # logger.status("--Done!--")
726
+ if images is None and face_models is None:
727
+ logger.error("Please provide `images` or `face_models`")
728
+ return (blended_face,)
729
+
730
+
731
+ class SaveFaceModel:
732
+ def __init__(self):
733
+ self.output_dir = FACE_MODELS_PATH
734
+
735
+ @classmethod
736
+ def INPUT_TYPES(s):
737
+ return {
738
+ "required": {
739
+ "save_mode": ("BOOLEAN", {"default": True, "label_off": "OFF", "label_on": "ON"}),
740
+ "face_model_name": ("STRING", {"default": "default"}),
741
+ "select_face_index": ("INT", {"default": 0, "min": 0}),
742
+ },
743
+ "optional": {
744
+ "image": ("IMAGE",),
745
+ "face_model": ("FACE_MODEL",),
746
+ }
747
+ }
748
+
749
+ RETURN_TYPES = ()
750
+ FUNCTION = "save_model"
751
+
752
+ OUTPUT_NODE = True
753
+
754
+ CATEGORY = "🌌 ReActor"
755
+
756
+ def save_model(self, save_mode, face_model_name, select_face_index, image=None, face_model=None, det_size=(640, 640)):
757
+ if save_mode and image is not None:
758
+ source = tensor_to_pil(image)
759
+ source = cv2.cvtColor(np.array(source), cv2.COLOR_RGB2BGR)
760
+ apply_patch(1)
761
+ logger.status("Building Face Model...")
762
+ face_model_raw = analyze_faces(source, det_size)
763
+ if len(face_model_raw) == 0:
764
+ det_size_half = half_det_size(det_size)
765
+ face_model_raw = analyze_faces(source, det_size_half)
766
+ try:
767
+ face_model = face_model_raw[select_face_index]
768
+ except:
769
+ logger.error("No face(s) found")
770
+ return face_model_name
771
+ logger.status("--Done!--")
772
+ if save_mode and (face_model != "none" or face_model is not None):
773
+ face_model_path = os.path.join(self.output_dir, face_model_name + ".safetensors")
774
+ save_face_model(face_model,face_model_path)
775
+ if image is None and face_model is None:
776
+ logger.error("Please provide `face_model` or `image`")
777
+ return face_model_name
778
+
779
+
780
+ class RestoreFace:
781
+ @classmethod
782
+ def INPUT_TYPES(s):
783
+ return {
784
+ "required": {
785
+ "image": ("IMAGE",),
786
+ "facedetection": (["retinaface_resnet50", "retinaface_mobile0.25", "YOLOv5l", "YOLOv5n"],),
787
+ "model": (get_model_names(get_restorers),),
788
+ "visibility": ("FLOAT", {"default": 1, "min": 0.0, "max": 1, "step": 0.05}),
789
+ "codeformer_weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1, "step": 0.05}),
790
+ },
791
+ }
792
+
793
+ RETURN_TYPES = ("IMAGE",)
794
+ FUNCTION = "execute"
795
+ CATEGORY = "🌌 ReActor"
796
+
797
+ # def __init__(self):
798
+ # self.face_helper = None
799
+ # self.face_size = 512
800
+
801
+ def execute(self, image, model, visibility, codeformer_weight, facedetection):
802
+ result = reactor.restore_face(self,image,model,visibility,codeformer_weight,facedetection)
803
+ return (result,)
804
+
805
+
806
+ class MaskHelper:
807
+ def __init__(self):
808
+ # self.threshold = 0.5
809
+ # self.dilation = 10
810
+ # self.crop_factor = 3.0
811
+ # self.drop_size = 1
812
+ self.labels = "all"
813
+ self.detailer_hook = None
814
+ self.device_mode = "AUTO"
815
+ self.detection_hint = "center-1"
816
+ # self.sam_dilation = 0
817
+ # self.sam_threshold = 0.93
818
+ # self.bbox_expansion = 0
819
+ # self.mask_hint_threshold = 0.7
820
+ # self.mask_hint_use_negative = "False"
821
+ # self.force_resize_width = 0
822
+ # self.force_resize_height = 0
823
+ # self.resize_behavior = "source_size"
824
+
825
+ @classmethod
826
+ def INPUT_TYPES(s):
827
+ bboxs = ["bbox/"+x for x in folder_paths.get_filename_list("ultralytics_bbox")]
828
+ segms = ["segm/"+x for x in folder_paths.get_filename_list("ultralytics_segm")]
829
+ sam_models = [x for x in folder_paths.get_filename_list("sams") if 'hq' not in x]
830
+ return {
831
+ "required": {
832
+ "image": ("IMAGE",),
833
+ "swapped_image": ("IMAGE",),
834
+ "bbox_model_name": (bboxs + segms, ),
835
+ "bbox_threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
836
+ "bbox_dilation": ("INT", {"default": 10, "min": -512, "max": 512, "step": 1}),
837
+ "bbox_crop_factor": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 100, "step": 0.1}),
838
+ "bbox_drop_size": ("INT", {"min": 1, "max": 8192, "step": 1, "default": 10}),
839
+ "sam_model_name": (sam_models, ),
840
+ "sam_dilation": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
841
+ "sam_threshold": ("FLOAT", {"default": 0.93, "min": 0.0, "max": 1.0, "step": 0.01}),
842
+ "bbox_expansion": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
843
+ "mask_hint_threshold": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
844
+ "mask_hint_use_negative": (["False", "Small", "Outter"], ),
845
+ "morphology_operation": (["dilate", "erode", "open", "close"],),
846
+ "morphology_distance": ("INT", {"default": 0, "min": 0, "max": 128, "step": 1}),
847
+ "blur_radius": ("INT", {"default": 9, "min": 0, "max": 48, "step": 1}),
848
+ "sigma_factor": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 3., "step": 0.01}),
849
+ },
850
+ "optional": {
851
+ "mask_optional": ("MASK",),
852
+ }
853
+ }
854
+
855
+ RETURN_TYPES = ("IMAGE","MASK","IMAGE","IMAGE")
856
+ RETURN_NAMES = ("IMAGE","MASK","MASK_PREVIEW","SWAPPED_FACE")
857
+ FUNCTION = "execute"
858
+ CATEGORY = "🌌 ReActor"
859
+
860
+ def execute(self, image, swapped_image, bbox_model_name, bbox_threshold, bbox_dilation, bbox_crop_factor, bbox_drop_size, sam_model_name, sam_dilation, sam_threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative, morphology_operation, morphology_distance, blur_radius, sigma_factor, mask_optional=None):
861
+
862
+ # images = [image[i:i + 1, ...] for i in range(image.shape[0])]
863
+
864
+ images = image
865
+
866
+ if mask_optional is None:
867
+
868
+ bbox_model_path = folder_paths.get_full_path("ultralytics", bbox_model_name)
869
+ bbox_model = subcore.load_yolo(bbox_model_path)
870
+ bbox_detector = subcore.UltraBBoxDetector(bbox_model)
871
+
872
+ segs = bbox_detector.detect(images, bbox_threshold, bbox_dilation, bbox_crop_factor, bbox_drop_size, self.detailer_hook)
873
+
874
+ if isinstance(self.labels, list):
875
+ self.labels = str(self.labels[0])
876
+
877
+ if self.labels is not None and self.labels != '':
878
+ self.labels = self.labels.split(',')
879
+ if len(self.labels) > 0:
880
+ segs, _ = masking_segs.filter(segs, self.labels)
881
+ # segs, _ = masking_segs.filter(segs, "all")
882
+
883
+ sam_modelname = folder_paths.get_full_path("sams", sam_model_name)
884
+
885
+ if 'vit_h' in sam_model_name:
886
+ model_kind = 'vit_h'
887
+ elif 'vit_l' in sam_model_name:
888
+ model_kind = 'vit_l'
889
+ else:
890
+ model_kind = 'vit_b'
891
+
892
+ sam = sam_model_registry[model_kind](checkpoint=sam_modelname)
893
+ size = os.path.getsize(sam_modelname)
894
+ sam.safe_to = core.SafeToGPU(size)
895
+
896
+ device = model_management.get_torch_device()
897
+
898
+ sam.safe_to.to_device(sam, device)
899
+
900
+ sam.is_auto_mode = self.device_mode == "AUTO"
901
+
902
+ combined_mask, _ = core.make_sam_mask_segmented(sam, segs, images, self.detection_hint, sam_dilation, sam_threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative)
903
+
904
+ else:
905
+ combined_mask = mask_optional
906
+
907
+ # *** MASK TO IMAGE ***:
908
+
909
+ mask_image = combined_mask.reshape((-1, 1, combined_mask.shape[-2], combined_mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
910
+
911
+ # *** MASK MORPH ***:
912
+
913
+ mask_image = core.tensor2mask(mask_image)
914
+
915
+ if morphology_operation == "dilate":
916
+ mask_image = self.dilate(mask_image, morphology_distance)
917
+ elif morphology_operation == "erode":
918
+ mask_image = self.erode(mask_image, morphology_distance)
919
+ elif morphology_operation == "open":
920
+ mask_image = self.erode(mask_image, morphology_distance)
921
+ mask_image = self.dilate(mask_image, morphology_distance)
922
+ elif morphology_operation == "close":
923
+ mask_image = self.dilate(mask_image, morphology_distance)
924
+ mask_image = self.erode(mask_image, morphology_distance)
925
+
926
+ # *** MASK BLUR ***:
927
+
928
+ if len(mask_image.size()) == 3:
929
+ mask_image = mask_image.unsqueeze(3)
930
+
931
+ mask_image = mask_image.permute(0, 3, 1, 2)
932
+ kernel_size = blur_radius * 2 + 1
933
+ sigma = sigma_factor * (0.6 * blur_radius - 0.3)
934
+ mask_image_final = self.gaussian_blur(mask_image, kernel_size, sigma).permute(0, 2, 3, 1)
935
+ if mask_image_final.size()[3] == 1:
936
+ mask_image_final = mask_image_final[:, :, :, 0]
937
+
938
+ # *** CUT BY MASK ***:
939
+
940
+ if len(swapped_image.shape) < 4:
941
+ C = 1
942
+ else:
943
+ C = swapped_image.shape[3]
944
+
945
+ # We operate on RGBA to keep the code clean and then convert back after
946
+ swapped_image = core.tensor2rgba(swapped_image)
947
+ mask = core.tensor2mask(mask_image_final)
948
+
949
+ # Scale the mask to be a matching size if it isn't
950
+ B, H, W, _ = swapped_image.shape
951
+ mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(H, W), mode='nearest')[:,0,:,:]
952
+ MB, _, _ = mask.shape
953
+
954
+ if MB < B:
955
+ assert(B % MB == 0)
956
+ mask = mask.repeat(B // MB, 1, 1)
957
+
958
+ # masks_to_boxes errors if the tensor is all zeros, so we'll add a single pixel and zero it out at the end
959
+ is_empty = ~torch.gt(torch.max(torch.reshape(mask,[MB, H * W]), dim=1).values, 0.)
960
+ mask[is_empty,0,0] = 1.
961
+ boxes = masks_to_boxes(mask)
962
+ mask[is_empty,0,0] = 0.
963
+
964
+ min_x = boxes[:,0]
965
+ min_y = boxes[:,1]
966
+ max_x = boxes[:,2]
967
+ max_y = boxes[:,3]
968
+
969
+ width = max_x - min_x + 1
970
+ height = max_y - min_y + 1
971
+
972
+ use_width = int(torch.max(width).item())
973
+ use_height = int(torch.max(height).item())
974
+
975
+ # if self.force_resize_width > 0:
976
+ # use_width = self.force_resize_width
977
+
978
+ # if self.force_resize_height > 0:
979
+ # use_height = self.force_resize_height
980
+
981
+ alpha_mask = torch.ones((B, H, W, 4))
982
+ alpha_mask[:,:,:,3] = mask
983
+
984
+ swapped_image = swapped_image * alpha_mask
985
+
986
+ cutted_image = torch.zeros((B, use_height, use_width, 4))
987
+ for i in range(0, B):
988
+ if not is_empty[i]:
989
+ ymin = int(min_y[i].item())
990
+ ymax = int(max_y[i].item())
991
+ xmin = int(min_x[i].item())
992
+ xmax = int(max_x[i].item())
993
+ single = (swapped_image[i, ymin:ymax+1, xmin:xmax+1,:]).unsqueeze(0)
994
+ resized = torch.nn.functional.interpolate(single.permute(0, 3, 1, 2), size=(use_height, use_width), mode='bicubic').permute(0, 2, 3, 1)
995
+ cutted_image[i] = resized[0]
996
+
997
+ # Preserve our type unless we were previously RGB and added non-opaque alpha due to the mask size
998
+ if C == 1:
999
+ cutted_image = core.tensor2mask(cutted_image)
1000
+ elif C == 3 and torch.min(cutted_image[:,:,:,3]) == 1:
1001
+ cutted_image = core.tensor2rgb(cutted_image)
1002
+
1003
+ # *** PASTE BY MASK ***:
1004
+
1005
+ image_base = core.tensor2rgba(images)
1006
+ image_to_paste = core.tensor2rgba(cutted_image)
1007
+ mask = core.tensor2mask(mask_image_final)
1008
+
1009
+ # Scale the mask to be a matching size if it isn't
1010
+ B, H, W, C = image_base.shape
1011
+ MB = mask.shape[0]
1012
+ PB = image_to_paste.shape[0]
1013
+
1014
+ if B < PB:
1015
+ assert(PB % B == 0)
1016
+ image_base = image_base.repeat(PB // B, 1, 1, 1)
1017
+ B, H, W, C = image_base.shape
1018
+ if MB < B:
1019
+ assert(B % MB == 0)
1020
+ mask = mask.repeat(B // MB, 1, 1)
1021
+ elif B < MB:
1022
+ assert(MB % B == 0)
1023
+ image_base = image_base.repeat(MB // B, 1, 1, 1)
1024
+ if PB < B:
1025
+ assert(B % PB == 0)
1026
+ image_to_paste = image_to_paste.repeat(B // PB, 1, 1, 1)
1027
+
1028
+ mask = torch.nn.functional.interpolate(mask.unsqueeze(1), size=(H, W), mode='nearest')[:,0,:,:]
1029
+ MB, MH, MW = mask.shape
1030
+
1031
+ # masks_to_boxes errors if the tensor is all zeros, so we'll add a single pixel and zero it out at the end
1032
+ is_empty = ~torch.gt(torch.max(torch.reshape(mask,[MB, MH * MW]), dim=1).values, 0.)
1033
+ mask[is_empty,0,0] = 1.
1034
+ boxes = masks_to_boxes(mask)
1035
+ mask[is_empty,0,0] = 0.
1036
+
1037
+ min_x = boxes[:,0]
1038
+ min_y = boxes[:,1]
1039
+ max_x = boxes[:,2]
1040
+ max_y = boxes[:,3]
1041
+ mid_x = (min_x + max_x) / 2
1042
+ mid_y = (min_y + max_y) / 2
1043
+
1044
+ target_width = max_x - min_x + 1
1045
+ target_height = max_y - min_y + 1
1046
+
1047
+ result = image_base.detach().clone()
1048
+ face_segment = mask_image_final
1049
+
1050
+ for i in range(0, MB):
1051
+ if is_empty[i]:
1052
+ continue
1053
+ else:
1054
+ image_index = i
1055
+ source_size = image_to_paste.size()
1056
+ SB, SH, SW, _ = image_to_paste.shape
1057
+
1058
+ # Figure out the desired size
1059
+ width = int(target_width[i].item())
1060
+ height = int(target_height[i].item())
1061
+ # if self.resize_behavior == "keep_ratio_fill":
1062
+ # target_ratio = width / height
1063
+ # actual_ratio = SW / SH
1064
+ # if actual_ratio > target_ratio:
1065
+ # width = int(height * actual_ratio)
1066
+ # elif actual_ratio < target_ratio:
1067
+ # height = int(width / actual_ratio)
1068
+ # elif self.resize_behavior == "keep_ratio_fit":
1069
+ # target_ratio = width / height
1070
+ # actual_ratio = SW / SH
1071
+ # if actual_ratio > target_ratio:
1072
+ # height = int(width / actual_ratio)
1073
+ # elif actual_ratio < target_ratio:
1074
+ # width = int(height * actual_ratio)
1075
+ # elif self.resize_behavior == "source_size" or self.resize_behavior == "source_size_unmasked":
1076
+
1077
+ width = SW
1078
+ height = SH
1079
+
1080
+ # Resize the image we're pasting if needed
1081
+ resized_image = image_to_paste[i].unsqueeze(0)
1082
+ # if SH != height or SW != width:
1083
+ # resized_image = torch.nn.functional.interpolate(resized_image.permute(0, 3, 1, 2), size=(height,width), mode='bicubic').permute(0, 2, 3, 1)
1084
+
1085
+ pasting = torch.ones([H, W, C])
1086
+ ymid = float(mid_y[i].item())
1087
+ ymin = int(math.floor(ymid - height / 2)) + 1
1088
+ ymax = int(math.floor(ymid + height / 2)) + 1
1089
+ xmid = float(mid_x[i].item())
1090
+ xmin = int(math.floor(xmid - width / 2)) + 1
1091
+ xmax = int(math.floor(xmid + width / 2)) + 1
1092
+
1093
+ _, source_ymax, source_xmax, _ = resized_image.shape
1094
+ source_ymin, source_xmin = 0, 0
1095
+
1096
+ if xmin < 0:
1097
+ source_xmin = abs(xmin)
1098
+ xmin = 0
1099
+ if ymin < 0:
1100
+ source_ymin = abs(ymin)
1101
+ ymin = 0
1102
+ if xmax > W:
1103
+ source_xmax -= (xmax - W)
1104
+ xmax = W
1105
+ if ymax > H:
1106
+ source_ymax -= (ymax - H)
1107
+ ymax = H
1108
+
1109
+ pasting[ymin:ymax, xmin:xmax, :] = resized_image[0, source_ymin:source_ymax, source_xmin:source_xmax, :]
1110
+ pasting[:, :, 3] = 1.
1111
+
1112
+ pasting_alpha = torch.zeros([H, W])
1113
+ pasting_alpha[ymin:ymax, xmin:xmax] = resized_image[0, source_ymin:source_ymax, source_xmin:source_xmax, 3]
1114
+
1115
+ # if self.resize_behavior == "keep_ratio_fill" or self.resize_behavior == "source_size_unmasked":
1116
+ # # If we explicitly want to fill the area, we are ok with extending outside
1117
+ # paste_mask = pasting_alpha.unsqueeze(2).repeat(1, 1, 4)
1118
+ # else:
1119
+ # paste_mask = torch.min(pasting_alpha, mask[i]).unsqueeze(2).repeat(1, 1, 4)
1120
+ paste_mask = torch.min(pasting_alpha, mask[i]).unsqueeze(2).repeat(1, 1, 4)
1121
+ result[image_index] = pasting * paste_mask + result[image_index] * (1. - paste_mask)
1122
+
1123
+ face_segment = result
1124
+
1125
+ face_segment[...,3] = mask[i]
1126
+
1127
+ result = rgba2rgb_tensor(result)
1128
+
1129
+ return (result,combined_mask,mask_image_final,face_segment,)
1130
+
1131
+ def gaussian_blur(self, image, kernel_size, sigma):
1132
+ kernel = torch.Tensor(kernel_size, kernel_size).to(device=image.device)
1133
+ center = kernel_size // 2
1134
+ variance = sigma**2
1135
+ for i in range(kernel_size):
1136
+ for j in range(kernel_size):
1137
+ x = i - center
1138
+ y = j - center
1139
+ kernel[i, j] = math.exp(-(x**2 + y**2)/(2*variance))
1140
+ kernel /= kernel.sum()
1141
+
1142
+ # Pad the input tensor
1143
+ padding = (kernel_size - 1) // 2
1144
+ input_pad = torch.nn.functional.pad(image, (padding, padding, padding, padding), mode='reflect')
1145
+
1146
+ # Reshape the padded input tensor for batched convolution
1147
+ batch_size, num_channels, height, width = image.shape
1148
+ input_reshaped = input_pad.reshape(batch_size*num_channels, 1, height+padding*2, width+padding*2)
1149
+
1150
+ # Perform batched convolution with the Gaussian kernel
1151
+ output_reshaped = torch.nn.functional.conv2d(input_reshaped, kernel.unsqueeze(0).unsqueeze(0))
1152
+
1153
+ # Reshape the output tensor to its original shape
1154
+ output_tensor = output_reshaped.reshape(batch_size, num_channels, height, width)
1155
+
1156
+ return output_tensor
1157
+
1158
+ def erode(self, image, distance):
1159
+ return 1. - self.dilate(1. - image, distance)
1160
+
1161
+ def dilate(self, image, distance):
1162
+ kernel_size = 1 + distance * 2
1163
+ # Add the channels dimension
1164
+ image = image.unsqueeze(1)
1165
+ out = torchfn.max_pool2d(image, kernel_size=kernel_size, stride=1, padding=kernel_size // 2).squeeze(1)
1166
+ return out
1167
+
1168
+
1169
+ class ImageDublicator:
1170
+ @classmethod
1171
+ def INPUT_TYPES(s):
1172
+ return {
1173
+ "required": {
1174
+ "image": ("IMAGE",),
1175
+ "count": ("INT", {"default": 1, "min": 0}),
1176
+ },
1177
+ }
1178
+
1179
+ RETURN_TYPES = ("IMAGE",)
1180
+ RETURN_NAMES = ("IMAGES",)
1181
+ OUTPUT_IS_LIST = (True,)
1182
+ FUNCTION = "execute"
1183
+ CATEGORY = "🌌 ReActor"
1184
+
1185
+ def execute(self, image, count):
1186
+ images = [image for i in range(count)]
1187
+ return (images,)
1188
+
1189
+
1190
+ class ImageRGBA2RGB:
1191
+ @classmethod
1192
+ def INPUT_TYPES(s):
1193
+ return {
1194
+ "required": {
1195
+ "image": ("IMAGE",),
1196
+ },
1197
+ }
1198
+
1199
+ RETURN_TYPES = ("IMAGE",)
1200
+ FUNCTION = "execute"
1201
+ CATEGORY = "🌌 ReActor"
1202
+
1203
+ def execute(self, image):
1204
+ out = rgba2rgb_tensor(image)
1205
+ return (out,)
1206
+
1207
+
1208
+ class MakeFaceModelBatch:
1209
+ @classmethod
1210
+ def INPUT_TYPES(s):
1211
+ return {
1212
+ "required": {
1213
+ "face_model1": ("FACE_MODEL",),
1214
+ },
1215
+ "optional": {
1216
+ "face_model2": ("FACE_MODEL",),
1217
+ "face_model3": ("FACE_MODEL",),
1218
+ "face_model4": ("FACE_MODEL",),
1219
+ "face_model5": ("FACE_MODEL",),
1220
+ "face_model6": ("FACE_MODEL",),
1221
+ "face_model7": ("FACE_MODEL",),
1222
+ "face_model8": ("FACE_MODEL",),
1223
+ "face_model9": ("FACE_MODEL",),
1224
+ "face_model10": ("FACE_MODEL",),
1225
+ },
1226
+ }
1227
+
1228
+ RETURN_TYPES = ("FACE_MODEL",)
1229
+ RETURN_NAMES = ("FACE_MODELS",)
1230
+ FUNCTION = "execute"
1231
+
1232
+ CATEGORY = "🌌 ReActor"
1233
+
1234
+ def execute(self, **kwargs):
1235
+ if len(kwargs) > 0:
1236
+ face_models = [value for value in kwargs.values()]
1237
+ return (face_models,)
1238
+ else:
1239
+ logger.error("Please provide at least 1 `face_model`")
1240
+ return (None,)
1241
+
1242
+
1243
+ class ReActorOptions:
1244
+ @classmethod
1245
+ def INPUT_TYPES(s):
1246
+ return {
1247
+ "required": {
1248
+ "input_faces_order": (
1249
+ ["left-right","right-left","top-bottom","bottom-top","small-large","large-small"], {"default": "large-small"}
1250
+ ),
1251
+ "input_faces_index": ("STRING", {"default": "0"}),
1252
+ "detect_gender_input": (["no","female","male"], {"default": "no"}),
1253
+ "source_faces_order": (
1254
+ ["left-right","right-left","top-bottom","bottom-top","small-large","large-small"], {"default": "large-small"}
1255
+ ),
1256
+ "source_faces_index": ("STRING", {"default": "0"}),
1257
+ "detect_gender_source": (["no","female","male"], {"default": "no"}),
1258
+ "console_log_level": ([0, 1, 2], {"default": 1}),
1259
+ }
1260
+ }
1261
+
1262
+ RETURN_TYPES = ("OPTIONS",)
1263
+ FUNCTION = "execute"
1264
+ CATEGORY = "🌌 ReActor"
1265
+
1266
+ def execute(self,input_faces_order, input_faces_index, detect_gender_input, source_faces_order, source_faces_index, detect_gender_source, console_log_level):
1267
+ options: dict = {
1268
+ "input_faces_order": input_faces_order,
1269
+ "input_faces_index": input_faces_index,
1270
+ "detect_gender_input": detect_gender_input,
1271
+ "source_faces_order": source_faces_order,
1272
+ "source_faces_index": source_faces_index,
1273
+ "detect_gender_source": detect_gender_source,
1274
+ "console_log_level": console_log_level,
1275
+ }
1276
+ return (options, )
1277
+
1278
+
1279
+ class ReActorFaceBoost:
1280
+ @classmethod
1281
+ def INPUT_TYPES(s):
1282
+ return {
1283
+ "required": {
1284
+ "enabled": ("BOOLEAN", {"default": True, "label_off": "OFF", "label_on": "ON"}),
1285
+ "boost_model": (get_model_names(get_restorers),),
1286
+ "interpolation": (["Nearest","Bilinear","Bicubic","Lanczos"], {"default": "Bicubic"}),
1287
+ "visibility": ("FLOAT", {"default": 1, "min": 0.1, "max": 1, "step": 0.05}),
1288
+ "codeformer_weight": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1, "step": 0.05}),
1289
+ "restore_with_main_after": ("BOOLEAN", {"default": False}),
1290
+ }
1291
+ }
1292
+
1293
+ RETURN_TYPES = ("FACE_BOOST",)
1294
+ FUNCTION = "execute"
1295
+ CATEGORY = "🌌 ReActor"
1296
+
1297
+ def execute(self,enabled,boost_model,interpolation,visibility,codeformer_weight,restore_with_main_after):
1298
+ face_boost: dict = {
1299
+ "enabled": enabled,
1300
+ "boost_model": boost_model,
1301
+ "interpolation": interpolation,
1302
+ "visibility": visibility,
1303
+ "codeformer_weight": codeformer_weight,
1304
+ "restore_with_main_after": restore_with_main_after,
1305
+ }
1306
+ return (face_boost, )
1307
+
1308
+ class ReActorUnload:
1309
+ @classmethod
1310
+ def INPUT_TYPES(s):
1311
+ return {
1312
+ "required": {
1313
+ "trigger": ("IMAGE", ),
1314
+ },
1315
+ }
1316
+
1317
+ RETURN_TYPES = ("IMAGE",)
1318
+ FUNCTION = "execute"
1319
+ CATEGORY = "🌌 ReActor"
1320
+
1321
+ def execute(self, trigger):
1322
+ unload_all_models()
1323
+ return (trigger,)
1324
+
1325
+
1326
+ NODE_CLASS_MAPPINGS = {
1327
+ # --- MAIN NODES ---
1328
+ "ReActorFaceSwap": reactor,
1329
+ "ReActorFaceSwapOpt": ReActorPlusOpt,
1330
+ "ReActorOptions": ReActorOptions,
1331
+ "ReActorFaceBoost": ReActorFaceBoost,
1332
+ "ReActorMaskHelper": MaskHelper,
1333
+ "ReActorSetWeight": ReActorWeight,
1334
+ # --- Operations with Face Models ---
1335
+ "ReActorSaveFaceModel": SaveFaceModel,
1336
+ "ReActorLoadFaceModel": LoadFaceModel,
1337
+ "ReActorBuildFaceModel": BuildFaceModel,
1338
+ "ReActorMakeFaceModelBatch": MakeFaceModelBatch,
1339
+ # --- Additional Nodes ---
1340
+ "ReActorRestoreFace": RestoreFace,
1341
+ "ReActorImageDublicator": ImageDublicator,
1342
+ "ImageRGBA2RGB": ImageRGBA2RGB,
1343
+ "ReActorUnload": ReActorUnload,
1344
+ }
1345
+
1346
+ NODE_DISPLAY_NAME_MAPPINGS = {
1347
+ # --- MAIN NODES ---
1348
+ "ReActorFaceSwap": "ReActor 🌌 Fast Face Swap",
1349
+ "ReActorFaceSwapOpt": "ReActor 🌌 Fast Face Swap [OPTIONS]",
1350
+ "ReActorOptions": "ReActor 🌌 Options",
1351
+ "ReActorFaceBoost": "ReActor 🌌 Face Booster",
1352
+ "ReActorMaskHelper": "ReActor 🌌 Masking Helper",
1353
+ "ReActorSetWeight": "ReActor 🌌 Set Face Swap Weight",
1354
+ # --- Operations with Face Models ---
1355
+ "ReActorSaveFaceModel": "Save Face Model 🌌 ReActor",
1356
+ "ReActorLoadFaceModel": "Load Face Model 🌌 ReActor",
1357
+ "ReActorBuildFaceModel": "Build Blended Face Model 🌌 ReActor",
1358
+ "ReActorMakeFaceModelBatch": "Make Face Model Batch 🌌 ReActor",
1359
+ # --- Additional Nodes ---
1360
+ "ReActorRestoreFace": "Restore Face 🌌 ReActor",
1361
+ "ReActorImageDublicator": "Image Dublicator (List) 🌌 ReActor",
1362
+ "ImageRGBA2RGB": "Convert RGBA to RGB 🌌 ReActor",
1363
+ "ReActorUnload": "Unload ReActor Models 🌌 ReActor",
1364
+ }
custom_nodes/ComfyUI-ReActor/pyproject.toml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ [project]
+ name = "comfyui-reactor"
+ description = "(SFW-Friendly) The Fast and Simple Face Swap Extension Node for ComfyUI, based on ReActor SD-WebUI Face Swap Extension"
+ version = "0.6.0-a1"
+ license = { file = "LICENSE" }
+ dependencies = ["insightface==0.7.3", "onnx>=1.14.0", "opencv-python>=4.7.0.72", "numpy==1.26.3", "segment_anything", "albumentations>=1.4.16", "ultralytics"]
+
+ [project.urls]
+ Repository = "https://github.com/Gourieff/ComfyUI-ReActor"
+ # Used by Comfy Registry https://comfyregistry.org
+
+ [tool.comfy]
+ PublisherId = "gourieff"
+ DisplayName = "ComfyUI-ReActor"
+ Icon = ""
custom_nodes/ComfyUI-ReActor/r_basicsr/__init__.py ADDED
@@ -0,0 +1,12 @@
+ # https://github.com/xinntao/BasicSR
+ # flake8: noqa
+ from .archs import *
+ from .data import *
+ from .losses import *
+ from .metrics import *
+ from .models import *
+ from .ops import *
+ from .test import *
+ from .train import *
+ from .utils import *
+ from .version import __gitsha__, __version__
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/__init__.py ADDED
@@ -0,0 +1,25 @@
+ import importlib
+ from copy import deepcopy
+ from os import path as osp
+
+ from r_basicsr.utils import get_root_logger, scandir
+ from r_basicsr.utils.registry import ARCH_REGISTRY
+
+ __all__ = ['build_network']
+
+ # automatically scan and import arch modules for registry
+ # scan all the files under the 'archs' folder and collect files ending with
+ # '_arch.py'
+ arch_folder = osp.dirname(osp.abspath(__file__))
+ arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
+ # import all the arch modules
+ _arch_modules = [importlib.import_module(f'r_basicsr.archs.{file_name}') for file_name in arch_filenames]
+
+
+ def build_network(opt):
+     opt = deepcopy(opt)
+     network_type = opt.pop('type')
+     net = ARCH_REGISTRY.get(network_type)(**opt)
+     logger = get_root_logger()
+     logger.info(f'Network [{net.__class__.__name__}] is created.')
+     return net
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/arch_util.py ADDED
@@ -0,0 +1,322 @@
1
+ import collections.abc
2
+ import math
3
+ import torch
4
+ import torchvision
5
+ import warnings
6
+ try:
7
+ from distutils.version import LooseVersion
8
+ except:
9
+ from packaging.version import Version
10
+ LooseVersion = Version
11
+ from itertools import repeat
12
+ from torch import nn as nn
13
+ from torch.nn import functional as F
14
+ from torch.nn import init as init
15
+ from torch.nn.modules.batchnorm import _BatchNorm
16
+
17
+ from r_basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
18
+ from r_basicsr.utils import get_root_logger
19
+
20
+
21
+ @torch.no_grad()
22
+ def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
23
+ """Initialize network weights.
24
+
25
+ Args:
26
+ module_list (list[nn.Module] | nn.Module): Modules to be initialized.
27
+ scale (float): Scale initialized weights, especially for residual
28
+ blocks. Default: 1.
29
+ bias_fill (float): The value to fill bias. Default: 0
30
+ kwargs (dict): Other arguments for initialization function.
31
+ """
32
+ if not isinstance(module_list, list):
33
+ module_list = [module_list]
34
+ for module in module_list:
35
+ for m in module.modules():
36
+ if isinstance(m, nn.Conv2d):
37
+ init.kaiming_normal_(m.weight, **kwargs)
38
+ m.weight.data *= scale
39
+ if m.bias is not None:
40
+ m.bias.data.fill_(bias_fill)
41
+ elif isinstance(m, nn.Linear):
42
+ init.kaiming_normal_(m.weight, **kwargs)
43
+ m.weight.data *= scale
44
+ if m.bias is not None:
45
+ m.bias.data.fill_(bias_fill)
46
+ elif isinstance(m, _BatchNorm):
47
+ init.constant_(m.weight, 1)
48
+ if m.bias is not None:
49
+ m.bias.data.fill_(bias_fill)
50
+
51
+
52
+ def make_layer(basic_block, num_basic_block, **kwarg):
53
+ """Make layers by stacking the same blocks.
54
+
55
+ Args:
56
+ basic_block (nn.module): nn.module class for basic block.
57
+ num_basic_block (int): number of blocks.
58
+
59
+ Returns:
60
+ nn.Sequential: Stacked blocks in nn.Sequential.
61
+ """
62
+ layers = []
63
+ for _ in range(num_basic_block):
64
+ layers.append(basic_block(**kwarg))
65
+ return nn.Sequential(*layers)
66
+
67
+
68
+ class ResidualBlockNoBN(nn.Module):
69
+ """Residual block without BN.
70
+
71
+ It has a style of:
72
+ ---Conv-ReLU-Conv-+-
73
+ |________________|
74
+
75
+ Args:
76
+ num_feat (int): Channel number of intermediate features.
77
+ Default: 64.
78
+ res_scale (float): Residual scale. Default: 1.
79
+ pytorch_init (bool): If set to True, use pytorch default init,
80
+ otherwise, use default_init_weights. Default: False.
81
+ """
82
+
83
+ def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
84
+ super(ResidualBlockNoBN, self).__init__()
85
+ self.res_scale = res_scale
86
+ self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
87
+ self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
88
+ self.relu = nn.ReLU(inplace=True)
89
+
90
+ if not pytorch_init:
91
+ default_init_weights([self.conv1, self.conv2], 0.1)
92
+
93
+ def forward(self, x):
94
+ identity = x
95
+ out = self.conv2(self.relu(self.conv1(x)))
96
+ return identity + out * self.res_scale
97
+
98
+
99
+ class Upsample(nn.Sequential):
100
+ """Upsample module.
101
+
102
+ Args:
103
+ scale (int): Scale factor. Supported scales: 2^n and 3.
104
+ num_feat (int): Channel number of intermediate features.
105
+ """
106
+
107
+ def __init__(self, scale, num_feat):
108
+ m = []
109
+ if (scale & (scale - 1)) == 0: # scale = 2^n
110
+ for _ in range(int(math.log(scale, 2))):
111
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
112
+ m.append(nn.PixelShuffle(2))
113
+ elif scale == 3:
114
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
115
+ m.append(nn.PixelShuffle(3))
116
+ else:
117
+ raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
118
+ super(Upsample, self).__init__(*m)
119
+
120
+
121
+ def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
122
+ """Warp an image or feature map with optical flow.
123
+
124
+ Args:
125
+ x (Tensor): Tensor with size (n, c, h, w).
126
+ flow (Tensor): Tensor with size (n, h, w, 2), normal value.
127
+ interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
128
+ padding_mode (str): 'zeros' or 'border' or 'reflection'.
129
+ Default: 'zeros'.
130
+ align_corners (bool): Before pytorch 1.3, the default value is
131
+ align_corners=True. After pytorch 1.3, the default value is
132
+ align_corners=False. Here, we use True as the default.
133
+
134
+ Returns:
135
+ Tensor: Warped image or feature map.
136
+ """
137
+ assert x.size()[-2:] == flow.size()[1:3]
138
+ _, _, h, w = x.size()
139
+ # create mesh grid
140
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
141
+ grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
142
+ grid.requires_grad = False
143
+
144
+ vgrid = grid + flow
145
+ # scale grid to [-1,1]
146
+ vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
147
+ vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
148
+ vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
149
+ output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
150
+
151
+ # TODO, what if align_corners=False
152
+ return output
153
+
154
+
155
+ def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
156
+ """Resize a flow according to ratio or shape.
157
+
158
+ Args:
159
+ flow (Tensor): Precomputed flow. shape [N, 2, H, W].
160
+ size_type (str): 'ratio' or 'shape'.
161
+ sizes (list[int | float]): the ratio for resizing or the final output
162
+ shape.
163
+ 1) The order of ratio should be [ratio_h, ratio_w]. For
164
+ downsampling, the ratio should be smaller than 1.0 (i.e., ratio
165
+ < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
166
+ ratio > 1.0).
167
+ 2) The order of output_size should be [out_h, out_w].
168
+ interp_mode (str): The mode of interpolation for resizing.
169
+ Default: 'bilinear'.
170
+ align_corners (bool): Whether align corners. Default: False.
171
+
172
+ Returns:
173
+ Tensor: Resized flow.
174
+ """
175
+ _, _, flow_h, flow_w = flow.size()
176
+ if size_type == 'ratio':
177
+ output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
178
+ elif size_type == 'shape':
179
+ output_h, output_w = sizes[0], sizes[1]
180
+ else:
181
+ raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
182
+
183
+ input_flow = flow.clone()
184
+ ratio_h = output_h / flow_h
185
+ ratio_w = output_w / flow_w
186
+ input_flow[:, 0, :, :] *= ratio_w
187
+ input_flow[:, 1, :, :] *= ratio_h
188
+ resized_flow = F.interpolate(
189
+ input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
190
+ return resized_flow
191
+
192
+
193
+ # TODO: may write a cpp file
194
+ def pixel_unshuffle(x, scale):
195
+ """ Pixel unshuffle.
196
+
197
+ Args:
198
+ x (Tensor): Input feature with shape (b, c, hh, hw).
199
+ scale (int): Downsample ratio.
200
+
201
+ Returns:
202
+ Tensor: the pixel unshuffled feature.
203
+ """
204
+ b, c, hh, hw = x.size()
205
+ out_channel = c * (scale**2)
206
+ assert hh % scale == 0 and hw % scale == 0
207
+ h = hh // scale
208
+ w = hw // scale
209
+ x_view = x.view(b, c, h, scale, w, scale)
210
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
211
+
212
+
213
+ class DCNv2Pack(ModulatedDeformConvPack):
214
+ """Modulated deformable conv for deformable alignment.
215
+
216
+ Different from the official DCNv2Pack, which generates offsets and masks
217
+ from the preceding features, this DCNv2Pack takes another different
218
+ features to generate offsets and masks.
219
+
220
+ Ref:
221
+ Delving Deep into Deformable Alignment in Video Super-Resolution.
222
+ """
223
+
224
+ def forward(self, x, feat):
225
+ out = self.conv_offset(feat)
226
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
227
+ offset = torch.cat((o1, o2), dim=1)
228
+ mask = torch.sigmoid(mask)
229
+
230
+ offset_absmean = torch.mean(torch.abs(offset))
231
+ if offset_absmean > 50:
232
+ logger = get_root_logger()
233
+ logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
234
+
235
+ if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
236
+ return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
237
+ self.dilation, mask)
238
+ else:
239
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
240
+ self.dilation, self.groups, self.deformable_groups)
241
+
242
+
243
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
244
+ # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
245
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
246
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
247
+ def norm_cdf(x):
248
+ # Computes standard normal cumulative distribution function
249
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
250
+
251
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
252
+ warnings.warn(
253
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
254
+ 'The distribution of values may be incorrect.',
255
+ stacklevel=2)
256
+
257
+ with torch.no_grad():
258
+ # Values are generated by using a truncated uniform distribution and
259
+ # then using the inverse CDF for the normal distribution.
260
+ # Get upper and lower cdf values
261
+ low = norm_cdf((a - mean) / std)
262
+ up = norm_cdf((b - mean) / std)
263
+
264
+ # Uniformly fill tensor with values from [low, up], then translate to
265
+ # [2l-1, 2u-1].
266
+ tensor.uniform_(2 * low - 1, 2 * up - 1)
267
+
268
+ # Use inverse cdf transform for normal distribution to get truncated
269
+ # standard normal
270
+ tensor.erfinv_()
271
+
272
+ # Transform to proper mean, std
273
+ tensor.mul_(std * math.sqrt(2.))
274
+ tensor.add_(mean)
275
+
276
+ # Clamp to ensure it's in the proper range
277
+ tensor.clamp_(min=a, max=b)
278
+ return tensor
279
+
280
+
281
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
282
+ r"""Fills the input Tensor with values drawn from a truncated
283
+ normal distribution.
284
+
285
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
286
+
287
+ The values are effectively drawn from the
288
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
289
+ with values outside :math:`[a, b]` redrawn until they are within
290
+ the bounds. The method used for generating the random values works
291
+ best when :math:`a \leq \text{mean} \leq b`.
292
+
293
+ Args:
294
+ tensor: an n-dimensional `torch.Tensor`
295
+ mean: the mean of the normal distribution
296
+ std: the standard deviation of the normal distribution
297
+ a: the minimum cutoff value
298
+ b: the maximum cutoff value
299
+
300
+ Examples:
301
+ >>> w = torch.empty(3, 5)
302
+ >>> nn.init.trunc_normal_(w)
303
+ """
304
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
305
+
306
+
307
+ # From PyTorch
308
+ def _ntuple(n):
309
+
310
+ def parse(x):
311
+ if isinstance(x, collections.abc.Iterable):
312
+ return x
313
+ return tuple(repeat(x, n))
314
+
315
+ return parse
316
+
317
+
318
+ to_1tuple = _ntuple(1)
319
+ to_2tuple = _ntuple(2)
320
+ to_3tuple = _ntuple(3)
321
+ to_4tuple = _ntuple(4)
322
+ to_ntuple = _ntuple
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/basicvsr_arch.py ADDED
@@ -0,0 +1,336 @@
1
+ import torch
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ from r_basicsr.utils.registry import ARCH_REGISTRY
6
+ from .arch_util import ResidualBlockNoBN, flow_warp, make_layer
7
+ from .edvr_arch import PCDAlignment, TSAFusion
8
+ from .spynet_arch import SpyNet
9
+
10
+
11
+ @ARCH_REGISTRY.register()
12
+ class BasicVSR(nn.Module):
13
+ """A recurrent network for video SR. Now only x4 is supported.
14
+
15
+ Args:
16
+ num_feat (int): Number of channels. Default: 64.
17
+ num_block (int): Number of residual blocks for each branch. Default: 15
18
+ spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
19
+ """
20
+
21
+ def __init__(self, num_feat=64, num_block=15, spynet_path=None):
22
+ super().__init__()
23
+ self.num_feat = num_feat
24
+
25
+ # alignment
26
+ self.spynet = SpyNet(spynet_path)
27
+
28
+ # propagation
29
+ self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
30
+ self.forward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
31
+
32
+ # reconstruction
33
+ self.fusion = nn.Conv2d(num_feat * 2, num_feat, 1, 1, 0, bias=True)
34
+ self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
35
+ self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
36
+ self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
37
+ self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
38
+
39
+ self.pixel_shuffle = nn.PixelShuffle(2)
40
+
41
+ # activation functions
42
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
43
+
44
+ def get_flow(self, x):
45
+ b, n, c, h, w = x.size()
46
+
47
+ x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
48
+ x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)
49
+
50
+ flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
51
+ flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)
52
+
53
+ return flows_forward, flows_backward
54
+
55
+ def forward(self, x):
56
+ """Forward function of BasicVSR.
57
+
58
+ Args:
59
+ x: Input frames with shape (b, n, c, h, w). n is the temporal dimension / number of frames.
60
+ """
61
+ flows_forward, flows_backward = self.get_flow(x)
62
+ b, n, _, h, w = x.size()
63
+
64
+ # backward branch
65
+ out_l = []
66
+ feat_prop = x.new_zeros(b, self.num_feat, h, w)
67
+ for i in range(n - 1, -1, -1):
68
+ x_i = x[:, i, :, :, :]
69
+ if i < n - 1:
70
+ flow = flows_backward[:, i, :, :, :]
71
+ feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
72
+ feat_prop = torch.cat([x_i, feat_prop], dim=1)
73
+ feat_prop = self.backward_trunk(feat_prop)
74
+ out_l.insert(0, feat_prop)
75
+
76
+ # forward branch
77
+ feat_prop = torch.zeros_like(feat_prop)
78
+ for i in range(0, n):
79
+ x_i = x[:, i, :, :, :]
80
+ if i > 0:
81
+ flow = flows_forward[:, i - 1, :, :, :]
82
+ feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
83
+
84
+ feat_prop = torch.cat([x_i, feat_prop], dim=1)
85
+ feat_prop = self.forward_trunk(feat_prop)
86
+
87
+ # upsample
88
+ out = torch.cat([out_l[i], feat_prop], dim=1)
89
+ out = self.lrelu(self.fusion(out))
90
+ out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
91
+ out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
92
+ out = self.lrelu(self.conv_hr(out))
93
+ out = self.conv_last(out)
94
+ base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
95
+ out += base
96
+ out_l[i] = out
97
+
98
+ return torch.stack(out_l, dim=1)
99
+
100
+
101
+ class ConvResidualBlocks(nn.Module):
102
+ """Conv and residual block used in BasicVSR.
103
+
104
+ Args:
105
+ num_in_ch (int): Number of input channels. Default: 3.
106
+ num_out_ch (int): Number of output channels. Default: 64.
107
+ num_block (int): Number of residual blocks. Default: 15.
108
+ """
109
+
110
+ def __init__(self, num_in_ch=3, num_out_ch=64, num_block=15):
111
+ super().__init__()
112
+ self.main = nn.Sequential(
113
+ nn.Conv2d(num_in_ch, num_out_ch, 3, 1, 1, bias=True), nn.LeakyReLU(negative_slope=0.1, inplace=True),
114
+ make_layer(ResidualBlockNoBN, num_block, num_feat=num_out_ch))
115
+
116
+ def forward(self, fea):
117
+ return self.main(fea)
118
+
119
+
120
+ @ARCH_REGISTRY.register()
121
+ class IconVSR(nn.Module):
122
+ """IconVSR, proposed also in the BasicVSR paper.
123
+
124
+ Args:
125
+ num_feat (int): Number of channels. Default: 64.
126
+ num_block (int): Number of residual blocks for each branch. Default: 15.
127
+ keyframe_stride (int): Keyframe stride. Default: 5.
128
+ temporal_padding (int): Temporal padding. Default: 2.
129
+ spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
130
+ edvr_path (str): Path to the pretrained EDVR model. Default: None.
131
+ """
132
+
133
+ def __init__(self,
134
+ num_feat=64,
135
+ num_block=15,
136
+ keyframe_stride=5,
137
+ temporal_padding=2,
138
+ spynet_path=None,
139
+ edvr_path=None):
140
+ super().__init__()
141
+
142
+ self.num_feat = num_feat
143
+ self.temporal_padding = temporal_padding
144
+ self.keyframe_stride = keyframe_stride
145
+
146
+ # keyframe_branch
147
+ self.edvr = EDVRFeatureExtractor(temporal_padding * 2 + 1, num_feat, edvr_path)
148
+ # alignment
149
+ self.spynet = SpyNet(spynet_path)
150
+
151
+ # propagation
152
+ self.backward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
153
+ self.backward_trunk = ConvResidualBlocks(num_feat + 3, num_feat, num_block)
154
+
155
+ self.forward_fusion = nn.Conv2d(2 * num_feat, num_feat, 3, 1, 1, bias=True)
156
+ self.forward_trunk = ConvResidualBlocks(2 * num_feat + 3, num_feat, num_block)
157
+
158
+ # reconstruction
159
+ self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1, bias=True)
160
+ self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1, bias=True)
161
+ self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
162
+ self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
163
+
164
+ self.pixel_shuffle = nn.PixelShuffle(2)
165
+
166
+ # activation functions
167
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
168
+
169
+ def pad_spatial(self, x):
170
+ """Apply padding spatially.
171
+
172
+ Since the PCD module in EDVR requires that the resolution is a multiple
173
+ of 4, we apply padding to the input LR images if their resolution is
174
+ not divisible by 4.
175
+
176
+ Args:
177
+ x (Tensor): Input LR sequence with shape (n, t, c, h, w).
178
+ Returns:
179
+ Tensor: Padded LR sequence with shape (n, t, c, h_pad, w_pad).
180
+ """
181
+ n, t, c, h, w = x.size()
182
+
183
+ pad_h = (4 - h % 4) % 4
184
+ pad_w = (4 - w % 4) % 4
185
+
186
+ # padding
187
+ x = x.view(-1, c, h, w)
188
+ x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')
189
+
190
+ return x.view(n, t, c, h + pad_h, w + pad_w)
191
+
192
+ def get_flow(self, x):
193
+ b, n, c, h, w = x.size()
194
+
195
+ x_1 = x[:, :-1, :, :, :].reshape(-1, c, h, w)
196
+ x_2 = x[:, 1:, :, :, :].reshape(-1, c, h, w)
197
+
198
+ flows_backward = self.spynet(x_1, x_2).view(b, n - 1, 2, h, w)
199
+ flows_forward = self.spynet(x_2, x_1).view(b, n - 1, 2, h, w)
200
+
201
+ return flows_forward, flows_backward
202
+
203
+ def get_keyframe_feature(self, x, keyframe_idx):
204
+ if self.temporal_padding == 2:
205
+ x = [x[:, [4, 3]], x, x[:, [-4, -5]]]
206
+ elif self.temporal_padding == 3:
207
+ x = [x[:, [6, 5, 4]], x, x[:, [-5, -6, -7]]]
208
+ x = torch.cat(x, dim=1)
209
+
210
+ num_frames = 2 * self.temporal_padding + 1
211
+ feats_keyframe = {}
212
+ for i in keyframe_idx:
213
+ feats_keyframe[i] = self.edvr(x[:, i:i + num_frames].contiguous())
214
+ return feats_keyframe
215
+
216
+ def forward(self, x):
217
+ b, n, _, h_input, w_input = x.size()
218
+
219
+ x = self.pad_spatial(x)
220
+ h, w = x.shape[3:]
221
+
222
+ keyframe_idx = list(range(0, n, self.keyframe_stride))
223
+ if keyframe_idx[-1] != n - 1:
224
+ keyframe_idx.append(n - 1) # last frame is a keyframe
225
+
226
+ # compute flow and keyframe features
227
+ flows_forward, flows_backward = self.get_flow(x)
228
+ feats_keyframe = self.get_keyframe_feature(x, keyframe_idx)
229
+
230
+ # backward branch
231
+ out_l = []
232
+ feat_prop = x.new_zeros(b, self.num_feat, h, w)
233
+ for i in range(n - 1, -1, -1):
234
+ x_i = x[:, i, :, :, :]
235
+ if i < n - 1:
236
+ flow = flows_backward[:, i, :, :, :]
237
+ feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
238
+ if i in keyframe_idx:
239
+ feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
240
+ feat_prop = self.backward_fusion(feat_prop)
241
+ feat_prop = torch.cat([x_i, feat_prop], dim=1)
242
+ feat_prop = self.backward_trunk(feat_prop)
243
+ out_l.insert(0, feat_prop)
244
+
245
+ # forward branch
246
+ feat_prop = torch.zeros_like(feat_prop)
247
+ for i in range(0, n):
248
+ x_i = x[:, i, :, :, :]
249
+ if i > 0:
250
+ flow = flows_forward[:, i - 1, :, :, :]
251
+ feat_prop = flow_warp(feat_prop, flow.permute(0, 2, 3, 1))
252
+ if i in keyframe_idx:
253
+ feat_prop = torch.cat([feat_prop, feats_keyframe[i]], dim=1)
254
+ feat_prop = self.forward_fusion(feat_prop)
255
+
256
+ feat_prop = torch.cat([x_i, out_l[i], feat_prop], dim=1)
257
+ feat_prop = self.forward_trunk(feat_prop)
258
+
259
+ # upsample
260
+ out = self.lrelu(self.pixel_shuffle(self.upconv1(feat_prop)))
261
+ out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
262
+ out = self.lrelu(self.conv_hr(out))
263
+ out = self.conv_last(out)
264
+ base = F.interpolate(x_i, scale_factor=4, mode='bilinear', align_corners=False)
265
+ out += base
266
+ out_l[i] = out
267
+
268
+ return torch.stack(out_l, dim=1)[..., :4 * h_input, :4 * w_input]
269
+
270
+
271
+ class EDVRFeatureExtractor(nn.Module):
272
+ """EDVR feature extractor used in IconVSR.
273
+
274
+ Args:
275
+ num_input_frame (int): Number of input frames.
276
+ num_feat (int): Number of feature channels
277
+ load_path (str): Path to the pretrained weights of EDVR. Default: None.
278
+ """
279
+
280
+ def __init__(self, num_input_frame, num_feat, load_path):
281
+
282
+ super(EDVRFeatureExtractor, self).__init__()
283
+
284
+ self.center_frame_idx = num_input_frame // 2
285
+
286
+ # extract pyramid features
287
+ self.conv_first = nn.Conv2d(3, num_feat, 3, 1, 1)
288
+ self.feature_extraction = make_layer(ResidualBlockNoBN, 5, num_feat=num_feat)
289
+ self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
290
+ self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
291
+ self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
292
+ self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
293
+
294
+ # pcd and tsa module
295
+ self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=8)
296
+ self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_input_frame, center_frame_idx=self.center_frame_idx)
297
+
298
+ # activation function
299
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
300
+
301
+ if load_path:
302
+ self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
303
+
304
+ def forward(self, x):
305
+ b, n, c, h, w = x.size()
306
+
307
+ # extract features for each frame
308
+ # L1
309
+ feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))
310
+ feat_l1 = self.feature_extraction(feat_l1)
311
+ # L2
312
+ feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
313
+ feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
314
+ # L3
315
+ feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
316
+ feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))
317
+
318
+ feat_l1 = feat_l1.view(b, n, -1, h, w)
319
+ feat_l2 = feat_l2.view(b, n, -1, h // 2, w // 2)
320
+ feat_l3 = feat_l3.view(b, n, -1, h // 4, w // 4)
321
+
322
+ # PCD alignment
323
+ ref_feat_l = [ # reference feature list
324
+ feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
325
+ feat_l3[:, self.center_frame_idx, :, :, :].clone()
326
+ ]
327
+ aligned_feat = []
328
+ for i in range(n):
329
+ nbr_feat_l = [ # neighboring feature list
330
+ feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
331
+ ]
332
+ aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
333
+ aligned_feat = torch.stack(aligned_feat, dim=1) # (b, t, c, h, w)
334
+
335
+ # TSA fusion
336
+ return self.fusion(aligned_feat)
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/basicvsrpp_arch.py ADDED
@@ -0,0 +1,407 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision
5
+ import warnings
6
+
7
+ from r_basicsr.archs.arch_util import flow_warp
8
+ from r_basicsr.archs.basicvsr_arch import ConvResidualBlocks
9
+ from r_basicsr.archs.spynet_arch import SpyNet
10
+ from r_basicsr.ops.dcn import ModulatedDeformConvPack
11
+ from r_basicsr.utils.registry import ARCH_REGISTRY
12
+
13
+
14
+ @ARCH_REGISTRY.register()
15
+ class BasicVSRPlusPlus(nn.Module):
16
+ """BasicVSR++ network structure.
17
+ Support either x4 upsampling or same size output. Since DCN is used in this
18
+ model, it can only be used with CUDA enabled. If CUDA is not enabled,
19
+ feature alignment will be skipped. Besides, we adopt the official DCN
20
+ implementation, and the torch version needs to be higher than 1.9.
21
+ Paper:
22
+ BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation
23
+ and Alignment
24
+ Args:
25
+ mid_channels (int, optional): Channel number of the intermediate
26
+ features. Default: 64.
27
+ num_blocks (int, optional): The number of residual blocks in each
28
+ propagation branch. Default: 7.
29
+ max_residue_magnitude (int): The maximum magnitude of the offset
30
+ residue (Eq. 6 in paper). Default: 10.
31
+ is_low_res_input (bool, optional): Whether the input is low-resolution
32
+ or not. If False, the output resolution is equal to the input
33
+ resolution. Default: True.
34
+ spynet_path (str): Path to the pretrained weights of SPyNet. Default: None.
35
+ cpu_cache_length (int, optional): When the length of sequence is larger
36
+ than this value, the intermediate features are sent to CPU. This
37
+ saves GPU memory, but slows down the inference speed. You can
38
+ increase this number if you have a GPU with large memory.
39
+ Default: 100.
40
+ """
41
+
42
+ def __init__(self,
43
+ mid_channels=64,
44
+ num_blocks=7,
45
+ max_residue_magnitude=10,
46
+ is_low_res_input=True,
47
+ spynet_path=None,
48
+ cpu_cache_length=100):
49
+
50
+ super().__init__()
51
+ self.mid_channels = mid_channels
52
+ self.is_low_res_input = is_low_res_input
53
+ self.cpu_cache_length = cpu_cache_length
54
+
55
+ # optical flow
56
+ self.spynet = SpyNet(spynet_path)
57
+
58
+ # feature extraction module
59
+ if is_low_res_input:
60
+ self.feat_extract = ConvResidualBlocks(3, mid_channels, 5)
61
+ else:
62
+ self.feat_extract = nn.Sequential(
63
+ nn.Conv2d(3, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
64
+ nn.Conv2d(mid_channels, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
65
+ ConvResidualBlocks(mid_channels, mid_channels, 5))
66
+
67
+ # propagation branches
68
+ self.deform_align = nn.ModuleDict()
69
+ self.backbone = nn.ModuleDict()
70
+ modules = ['backward_1', 'forward_1', 'backward_2', 'forward_2']
71
+ for i, module in enumerate(modules):
72
+ if torch.cuda.is_available():
73
+ self.deform_align[module] = SecondOrderDeformableAlignment(
74
+ 2 * mid_channels,
75
+ mid_channels,
76
+ 3,
77
+ padding=1,
78
+ deformable_groups=16,
79
+ max_residue_magnitude=max_residue_magnitude)
80
+ self.backbone[module] = ConvResidualBlocks((2 + i) * mid_channels, mid_channels, num_blocks)
81
+
82
+ # upsampling module
83
+ self.reconstruction = ConvResidualBlocks(5 * mid_channels, mid_channels, 5)
84
+
85
+ self.upconv1 = nn.Conv2d(mid_channels, mid_channels * 4, 3, 1, 1, bias=True)
86
+ self.upconv2 = nn.Conv2d(mid_channels, 64 * 4, 3, 1, 1, bias=True)
87
+
88
+ self.pixel_shuffle = nn.PixelShuffle(2)
89
+
90
+ self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
91
+ self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
92
+ self.img_upsample = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
93
+
94
+ # activation function
95
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
96
+
97
+ # check if the sequence is augmented by flipping
98
+ self.is_mirror_extended = False
99
+
100
+ if len(self.deform_align) > 0:
101
+ self.is_with_alignment = True
102
+ else:
103
+ self.is_with_alignment = False
104
+ warnings.warn('Deformable alignment module is not added. '
105
+ 'Probably your CUDA is not configured correctly. DCN can only '
106
+ 'be used with CUDA enabled. Alignment is skipped now.')
107
+
108
+ def check_if_mirror_extended(self, lqs):
109
+ """Check whether the input is a mirror-extended sequence.
110
+ If mirror-extended, the i-th (i=0, ..., t-1) frame is equal to the
111
+ (t-1-i)-th frame.
112
+ Args:
113
+ lqs (tensor): Input low quality (LQ) sequence with
114
+ shape (n, t, c, h, w).
115
+ """
116
+
117
+ if lqs.size(1) % 2 == 0:
118
+ lqs_1, lqs_2 = torch.chunk(lqs, 2, dim=1)
119
+ if torch.norm(lqs_1 - lqs_2.flip(1)) == 0:
120
+ self.is_mirror_extended = True
121
+
122
+ def compute_flow(self, lqs):
123
+ """Compute optical flow using SPyNet for feature alignment.
124
+ Note that if the input is a mirror-extended sequence, 'flows_forward'
125
+ is not needed, since it is equal to 'flows_backward.flip(1)'.
126
+ Args:
127
+ lqs (tensor): Input low quality (LQ) sequence with
128
+ shape (n, t, c, h, w).
129
+ Return:
130
+ tuple(Tensor): Optical flow. 'flows_forward' corresponds to the
131
+ flows used for forward-time propagation (current to previous).
132
+ 'flows_backward' corresponds to the flows used for
133
+ backward-time propagation (current to next).
134
+ """
135
+
136
+ n, t, c, h, w = lqs.size()
137
+ lqs_1 = lqs[:, :-1, :, :, :].reshape(-1, c, h, w)
138
+ lqs_2 = lqs[:, 1:, :, :, :].reshape(-1, c, h, w)
139
+
140
+ flows_backward = self.spynet(lqs_1, lqs_2).view(n, t - 1, 2, h, w)
141
+
142
+ if self.is_mirror_extended: # flows_forward = flows_backward.flip(1)
143
+ flows_forward = flows_backward.flip(1)
144
+ else:
145
+ flows_forward = self.spynet(lqs_2, lqs_1).view(n, t - 1, 2, h, w)
146
+
147
+ if self.cpu_cache:
148
+ flows_backward = flows_backward.cpu()
149
+ flows_forward = flows_forward.cpu()
150
+
151
+ return flows_forward, flows_backward
152
+
153
+ def propagate(self, feats, flows, module_name):
154
+ """Propagate the latent features throughout the sequence.
155
+ Args:
156
+ feats dict(list[tensor]): Features from previous branches. Each
157
+ component is a list of tensors with shape (n, c, h, w).
158
+ flows (tensor): Optical flows with shape (n, t - 1, 2, h, w).
159
+ module_name (str): The name of the propagation branches. Can either
160
+ be 'backward_1', 'forward_1', 'backward_2', 'forward_2'.
161
+ Return:
162
+ dict(list[tensor]): A dictionary containing all the propagated
163
+ features. Each key in the dictionary corresponds to a
164
+ propagation branch, which is represented by a list of tensors.
165
+ """
166
+
167
+ n, t, _, h, w = flows.size()
168
+
169
+ frame_idx = range(0, t + 1)
170
+ flow_idx = range(-1, t)
171
+ mapping_idx = list(range(0, len(feats['spatial'])))
172
+ mapping_idx += mapping_idx[::-1]
173
+
174
+ if 'backward' in module_name:
175
+ frame_idx = frame_idx[::-1]
176
+ flow_idx = frame_idx
177
+
178
+ feat_prop = flows.new_zeros(n, self.mid_channels, h, w)
179
+ for i, idx in enumerate(frame_idx):
180
+ feat_current = feats['spatial'][mapping_idx[idx]]
181
+ if self.cpu_cache:
182
+ feat_current = feat_current.cuda()
183
+ feat_prop = feat_prop.cuda()
184
+ # second-order deformable alignment
185
+ if i > 0 and self.is_with_alignment:
186
+ flow_n1 = flows[:, flow_idx[i], :, :, :]
187
+ if self.cpu_cache:
188
+ flow_n1 = flow_n1.cuda()
189
+
190
+ cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))
191
+
192
+ # initialize second-order features
193
+ feat_n2 = torch.zeros_like(feat_prop)
194
+ flow_n2 = torch.zeros_like(flow_n1)
195
+ cond_n2 = torch.zeros_like(cond_n1)
196
+
197
+ if i > 1: # second-order features
198
+ feat_n2 = feats[module_name][-2]
199
+ if self.cpu_cache:
200
+ feat_n2 = feat_n2.cuda()
201
+
202
+ flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
203
+ if self.cpu_cache:
204
+ flow_n2 = flow_n2.cuda()
205
+
206
+ flow_n2 = flow_n1 + flow_warp(flow_n2, flow_n1.permute(0, 2, 3, 1))
207
+ cond_n2 = flow_warp(feat_n2, flow_n2.permute(0, 2, 3, 1))
208
+
209
+ # flow-guided deformable convolution
210
+ cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
211
+ feat_prop = torch.cat([feat_prop, feat_n2], dim=1)
212
+ feat_prop = self.deform_align[module_name](feat_prop, cond, flow_n1, flow_n2)
213
+
214
+ # concatenate and residual blocks
215
+ feat = [feat_current] + [feats[k][idx] for k in feats if k not in ['spatial', module_name]] + [feat_prop]
216
+ if self.cpu_cache:
217
+ feat = [f.cuda() for f in feat]
218
+
219
+ feat = torch.cat(feat, dim=1)
220
+ feat_prop = feat_prop + self.backbone[module_name](feat)
221
+ feats[module_name].append(feat_prop)
222
+
223
+ if self.cpu_cache:
224
+ feats[module_name][-1] = feats[module_name][-1].cpu()
225
+ torch.cuda.empty_cache()
226
+
227
+ if 'backward' in module_name:
228
+ feats[module_name] = feats[module_name][::-1]
229
+
230
+ return feats
231
+
232
+ def upsample(self, lqs, feats):
233
+ """Compute the output image given the features.
234
+ Args:
235
+ lqs (tensor): Input low quality (LQ) sequence with
236
+ shape (n, t, c, h, w).
237
+ feats (dict): The features from the propagation branches.
238
+ Returns:
239
+ Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
240
+ """
241
+
242
+ outputs = []
243
+ num_outputs = len(feats['spatial'])
244
+
245
+ mapping_idx = list(range(0, num_outputs))
246
+ mapping_idx += mapping_idx[::-1]
247
+
248
+ for i in range(0, lqs.size(1)):
249
+ hr = [feats[k].pop(0) for k in feats if k != 'spatial']
250
+ hr.insert(0, feats['spatial'][mapping_idx[i]])
251
+ hr = torch.cat(hr, dim=1)
252
+ if self.cpu_cache:
253
+ hr = hr.cuda()
254
+
255
+ hr = self.reconstruction(hr)
256
+ hr = self.lrelu(self.pixel_shuffle(self.upconv1(hr)))
257
+ hr = self.lrelu(self.pixel_shuffle(self.upconv2(hr)))
258
+ hr = self.lrelu(self.conv_hr(hr))
259
+ hr = self.conv_last(hr)
260
+ if self.is_low_res_input:
261
+ hr += self.img_upsample(lqs[:, i, :, :, :])
262
+ else:
263
+ hr += lqs[:, i, :, :, :]
264
+
265
+ if self.cpu_cache:
266
+ hr = hr.cpu()
267
+ torch.cuda.empty_cache()
268
+
269
+ outputs.append(hr)
270
+
271
+ return torch.stack(outputs, dim=1)
272
+
273
+ def forward(self, lqs):
274
+ """Forward function for BasicVSR++.
275
+ Args:
276
+ lqs (tensor): Input low quality (LQ) sequence with
277
+ shape (n, t, c, h, w).
278
+ Returns:
279
+ Tensor: Output HR sequence with shape (n, t, c, 4h, 4w).
280
+ """
281
+
282
+ n, t, c, h, w = lqs.size()
283
+
284
+ # whether to cache the features in CPU
285
+ self.cpu_cache = True if t > self.cpu_cache_length else False
286
+
287
+ if self.is_low_res_input:
288
+ lqs_downsample = lqs.clone()
289
+ else:
290
+ lqs_downsample = F.interpolate(
291
+ lqs.view(-1, c, h, w), scale_factor=0.25, mode='bicubic').view(n, t, c, h // 4, w // 4)
292
+
293
+ # check whether the input is an extended sequence
294
+ self.check_if_mirror_extended(lqs)
295
+
296
+ feats = {}
297
+ # compute spatial features
298
+ if self.cpu_cache:
299
+ feats['spatial'] = []
300
+ for i in range(0, t):
301
+ feat = self.feat_extract(lqs[:, i, :, :, :]).cpu()
302
+ feats['spatial'].append(feat)
303
+ torch.cuda.empty_cache()
304
+ else:
305
+ feats_ = self.feat_extract(lqs.view(-1, c, h, w))
306
+ h, w = feats_.shape[2:]
307
+ feats_ = feats_.view(n, t, -1, h, w)
308
+ feats['spatial'] = [feats_[:, i, :, :, :] for i in range(0, t)]
309
+
310
+ # compute optical flow using the low-res inputs
311
+ assert lqs_downsample.size(3) >= 64 and lqs_downsample.size(4) >= 64, (
312
+ 'The height and width of low-res inputs must be at least 64, '
313
+ f'but got {h} and {w}.')
314
+ flows_forward, flows_backward = self.compute_flow(lqs_downsample)
315
+
316
+ # feature propagation
317
+ for iter_ in [1, 2]:
318
+ for direction in ['backward', 'forward']:
319
+ module = f'{direction}_{iter_}'
320
+
321
+ feats[module] = []
322
+
323
+ if direction == 'backward':
324
+ flows = flows_backward
325
+ elif flows_forward is not None:
326
+ flows = flows_forward
327
+ else:
328
+ flows = flows_backward.flip(1)
329
+
330
+ feats = self.propagate(feats, flows, module)
331
+ if self.cpu_cache:
332
+ del flows
333
+ torch.cuda.empty_cache()
334
+
335
+ return self.upsample(lqs, feats)
336
+
337
+
338
+ class SecondOrderDeformableAlignment(ModulatedDeformConvPack):
339
+ """Second-order deformable alignment module.
340
+ Args:
341
+ in_channels (int): Same as nn.Conv2d.
342
+ out_channels (int): Same as nn.Conv2d.
343
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
344
+ stride (int or tuple[int]): Same as nn.Conv2d.
345
+ padding (int or tuple[int]): Same as nn.Conv2d.
346
+ dilation (int or tuple[int]): Same as nn.Conv2d.
347
+ groups (int): Same as nn.Conv2d.
348
+ bias (bool or str): If specified as `auto`, it will be decided by the
349
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
350
+ False.
351
+ max_residue_magnitude (int): The maximum magnitude of the offset
352
+ residue (Eq. 6 in paper). Default: 10.
353
+ """
354
+
355
+ def __init__(self, *args, **kwargs):
356
+ self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)
357
+
358
+ super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)
359
+
360
+ self.conv_offset = nn.Sequential(
361
+ nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1),
362
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
363
+ nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
364
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
365
+ nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
366
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
367
+ nn.Conv2d(self.out_channels, 27 * self.deformable_groups, 3, 1, 1),
368
+ )
369
+
370
+ self.init_offset()
371
+
372
+ def init_offset(self):
373
+
374
+ def _constant_init(module, val, bias=0):
375
+ if hasattr(module, 'weight') and module.weight is not None:
376
+ nn.init.constant_(module.weight, val)
377
+ if hasattr(module, 'bias') and module.bias is not None:
378
+ nn.init.constant_(module.bias, bias)
379
+
380
+ _constant_init(self.conv_offset[-1], val=0, bias=0)
381
+
382
+ def forward(self, x, extra_feat, flow_1, flow_2):
383
+ extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1)
384
+ out = self.conv_offset(extra_feat)
385
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
386
+
387
+ # offset
388
+ offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1))
389
+ offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
390
+ offset_1 = offset_1 + flow_1.flip(1).repeat(1, offset_1.size(1) // 2, 1, 1)
391
+ offset_2 = offset_2 + flow_2.flip(1).repeat(1, offset_2.size(1) // 2, 1, 1)
392
+ offset = torch.cat([offset_1, offset_2], dim=1)
393
+
394
+ # mask
395
+ mask = torch.sigmoid(mask)
396
+
397
+ return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
398
+ self.dilation, mask)
399
+
400
+
401
+ # if __name__ == '__main__':
402
+ # spynet_path = 'experiments/pretrained_models/flownet/spynet_sintel_final-3d2a1287.pth'
403
+ # model = BasicVSRPlusPlus(spynet_path=spynet_path).cuda()
404
+ # input = torch.rand(1, 2, 3, 64, 64).cuda()
405
+ # output = model(input)
406
+ # print('===================')
407
+ # print(output.shape)
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/dfdnet_arch.py ADDED
@@ -0,0 +1,169 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.nn.utils.spectral_norm import spectral_norm
6
+
7
+ from r_basicsr.utils.registry import ARCH_REGISTRY
8
+ from .dfdnet_util import AttentionBlock, Blur, MSDilationBlock, UpResBlock, adaptive_instance_normalization
9
+ from .vgg_arch import VGGFeatureExtractor
10
+
11
+
12
+ class SFTUpBlock(nn.Module):
13
+ """Spatial feature transform (SFT) with upsampling block.
14
+
15
+ Args:
16
+ in_channel (int): Number of input channels.
17
+ out_channel (int): Number of output channels.
18
+ kernel_size (int): Kernel size in convolutions. Default: 3.
19
+ padding (int): Padding in convolutions. Default: 1.
20
+ """
21
+
22
+ def __init__(self, in_channel, out_channel, kernel_size=3, padding=1):
23
+ super(SFTUpBlock, self).__init__()
24
+ self.conv1 = nn.Sequential(
25
+ Blur(in_channel),
26
+ spectral_norm(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding)),
27
+ nn.LeakyReLU(0.04, True),
28
+ # The official code applies LeakyReLU(0.2) twice here, so a single 0.04 slope is equivalent
29
+ )
30
+ self.convup = nn.Sequential(
31
+ nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
32
+ spectral_norm(nn.Conv2d(out_channel, out_channel, kernel_size, padding=padding)),
33
+ nn.LeakyReLU(0.2, True),
34
+ )
35
+
36
+ # for SFT scale and shift
37
+ self.scale_block = nn.Sequential(
38
+ spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
39
+ spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)))
40
+ self.shift_block = nn.Sequential(
41
+ spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
42
+ spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)), nn.Sigmoid())
43
+ # The official code uses a sigmoid for the shift block; the reason is unclear
44
+
45
+ def forward(self, x, updated_feat):
46
+ out = self.conv1(x)
47
+ # SFT
48
+ scale = self.scale_block(updated_feat)
49
+ shift = self.shift_block(updated_feat)
50
+ out = out * scale + shift
51
+ # upsample
52
+ out = self.convup(out)
53
+ return out
54
+
55
+
56
+ @ARCH_REGISTRY.register()
57
+ class DFDNet(nn.Module):
58
+ """DFDNet: Deep Face Dictionary Network.
59
+
60
+ It only processes faces with 512x512 size.
61
+
62
+ Args:
63
+ num_feat (int): Number of feature channels.
64
+ dict_path (str): Path to the facial component dictionary.
65
+ """
66
+
67
+ def __init__(self, num_feat, dict_path):
68
+ super().__init__()
69
+ self.parts = ['left_eye', 'right_eye', 'nose', 'mouth']
70
+ # part_sizes: [80, 80, 50, 110]
71
+ channel_sizes = [128, 256, 512, 512]
72
+ self.feature_sizes = np.array([256, 128, 64, 32])
73
+ self.vgg_layers = ['relu2_2', 'relu3_4', 'relu4_4', 'conv5_4']
74
+ self.flag_dict_device = False
75
+
76
+ # dict
77
+ self.dict = torch.load(dict_path)
78
+
79
+ # vgg face extractor
80
+ self.vgg_extractor = VGGFeatureExtractor(
81
+ layer_name_list=self.vgg_layers,
82
+ vgg_type='vgg19',
83
+ use_input_norm=True,
84
+ range_norm=True,
85
+ requires_grad=False)
86
+
87
+ # attention block for fusing dictionary features and input features
88
+ self.attn_blocks = nn.ModuleDict()
89
+ for idx, feat_size in enumerate(self.feature_sizes):
90
+ for name in self.parts:
91
+ self.attn_blocks[f'{name}_{feat_size}'] = AttentionBlock(channel_sizes[idx])
92
+
93
+ # multi scale dilation block
94
+ self.multi_scale_dilation = MSDilationBlock(num_feat * 8, dilation=[4, 3, 2, 1])
95
+
96
+ # upsampling and reconstruction
97
+ self.upsample0 = SFTUpBlock(num_feat * 8, num_feat * 8)
98
+ self.upsample1 = SFTUpBlock(num_feat * 8, num_feat * 4)
99
+ self.upsample2 = SFTUpBlock(num_feat * 4, num_feat * 2)
100
+ self.upsample3 = SFTUpBlock(num_feat * 2, num_feat)
101
+ self.upsample4 = nn.Sequential(
102
+ spectral_norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1)), nn.LeakyReLU(0.2, True), UpResBlock(num_feat),
103
+ UpResBlock(num_feat), nn.Conv2d(num_feat, 3, kernel_size=3, stride=1, padding=1), nn.Tanh())
104
+
105
+ def swap_feat(self, vgg_feat, updated_feat, dict_feat, location, part_name, f_size):
106
+ """swap the features from the dictionary."""
107
+ # get the original vgg features
108
+ part_feat = vgg_feat[:, :, location[1]:location[3], location[0]:location[2]].clone()
109
+ # resize original vgg features
110
+ part_resize_feat = F.interpolate(part_feat, dict_feat.size()[2:4], mode='bilinear', align_corners=False)
111
+ # use adaptive instance normalization to adjust color and illuminations
112
+ dict_feat = adaptive_instance_normalization(dict_feat, part_resize_feat)
113
+ # get similarity scores
114
+ similarity_score = F.conv2d(part_resize_feat, dict_feat)
115
+ similarity_score = F.softmax(similarity_score.view(-1), dim=0)
116
+ # select the most similar features in the dict (after norm)
117
+ select_idx = torch.argmax(similarity_score)
118
+ swap_feat = F.interpolate(dict_feat[select_idx:select_idx + 1], part_feat.size()[2:4])
119
+ # attention
120
+ attn = self.attn_blocks[f'{part_name}_' + str(f_size)](swap_feat - part_feat)
121
+ attn_feat = attn * swap_feat
122
+ # update features
123
+ updated_feat[:, :, location[1]:location[3], location[0]:location[2]] = attn_feat + part_feat
124
+ return updated_feat
125
+
126
+ def put_dict_to_device(self, x):
127
+ if self.flag_dict_device is False:
128
+ for k, v in self.dict.items():
129
+ for kk, vv in v.items():
130
+ self.dict[k][kk] = vv.to(x)
131
+ self.flag_dict_device = True
132
+
133
+ def forward(self, x, part_locations):
134
+ """
135
+ Only supports testing with batch size = 1.
136
+
137
+ Args:
138
+ x (Tensor): Input faces with shape (b, c, 512, 512).
139
+ part_locations (list[Tensor]): Part locations.
140
+ """
141
+ self.put_dict_to_device(x)
142
+ # extract vggface features
143
+ vgg_features = self.vgg_extractor(x)
144
+ # update vggface features using the dictionary for each part
145
+ updated_vgg_features = []
146
+ batch = 0 # only supports testing with batch size = 1 (batch index 0)
147
+ for vgg_layer, f_size in zip(self.vgg_layers, self.feature_sizes):
148
+ dict_features = self.dict[f'{f_size}']
149
+ vgg_feat = vgg_features[vgg_layer]
150
+ updated_feat = vgg_feat.clone()
151
+
152
+ # swap features from dictionary
153
+ for part_idx, part_name in enumerate(self.parts):
154
+ location = (part_locations[part_idx][batch] // (512 / f_size)).int()
155
+ updated_feat = self.swap_feat(vgg_feat, updated_feat, dict_features[part_name], location, part_name,
156
+ f_size)
157
+
158
+ updated_vgg_features.append(updated_feat)
159
+
160
+ vgg_feat_dilation = self.multi_scale_dilation(vgg_features['conv5_4'])
161
+ # use updated vgg features to modulate the upsampled features with
162
+ # SFT (Spatial Feature Transform) scaling and shifting manner.
163
+ upsampled_feat = self.upsample0(vgg_feat_dilation, updated_vgg_features[3])
164
+ upsampled_feat = self.upsample1(upsampled_feat, updated_vgg_features[2])
165
+ upsampled_feat = self.upsample2(upsampled_feat, updated_vgg_features[1])
166
+ upsampled_feat = self.upsample3(upsampled_feat, updated_vgg_features[0])
167
+ out = self.upsample4(upsampled_feat)
168
+
169
+ return out
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/dfdnet_util.py ADDED
@@ -0,0 +1,162 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.autograd import Function
5
+ from torch.nn.utils.spectral_norm import spectral_norm
6
+
7
+
8
+ class BlurFunctionBackward(Function):
9
+
10
+ @staticmethod
11
+ def forward(ctx, grad_output, kernel, kernel_flip):
12
+ ctx.save_for_backward(kernel, kernel_flip)
13
+ grad_input = F.conv2d(grad_output, kernel_flip, padding=1, groups=grad_output.shape[1])
14
+ return grad_input
15
+
16
+ @staticmethod
17
+ def backward(ctx, gradgrad_output):
18
+ kernel, _ = ctx.saved_tensors
19
+ grad_input = F.conv2d(gradgrad_output, kernel, padding=1, groups=gradgrad_output.shape[1])
20
+ return grad_input, None, None
21
+
22
+
23
+ class BlurFunction(Function):
24
+
25
+ @staticmethod
26
+ def forward(ctx, x, kernel, kernel_flip):
27
+ ctx.save_for_backward(kernel, kernel_flip)
28
+ output = F.conv2d(x, kernel, padding=1, groups=x.shape[1])
29
+ return output
30
+
31
+ @staticmethod
32
+ def backward(ctx, grad_output):
33
+ kernel, kernel_flip = ctx.saved_tensors
34
+ grad_input = BlurFunctionBackward.apply(grad_output, kernel, kernel_flip)
35
+ return grad_input, None, None
36
+
37
+
38
+ blur = BlurFunction.apply
39
+
40
+
41
+ class Blur(nn.Module):
42
+
43
+ def __init__(self, channel):
44
+ super().__init__()
45
+ kernel = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
46
+ kernel = kernel.view(1, 1, 3, 3)
47
+ kernel = kernel / kernel.sum()
48
+ kernel_flip = torch.flip(kernel, [2, 3])
49
+
50
+ self.kernel = kernel.repeat(channel, 1, 1, 1)
51
+ self.kernel_flip = kernel_flip.repeat(channel, 1, 1, 1)
52
+
53
+ def forward(self, x):
54
+ return blur(x, self.kernel.type_as(x), self.kernel_flip.type_as(x))
55
+
56
+
57
+ def calc_mean_std(feat, eps=1e-5):
58
+ """Calculate mean and std for adaptive_instance_normalization.
59
+
60
+ Args:
61
+ feat (Tensor): 4D tensor.
62
+ eps (float): A small value added to the variance to avoid
63
+ divide-by-zero. Default: 1e-5.
64
+ """
65
+ size = feat.size()
66
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
67
+ n, c = size[:2]
68
+ feat_var = feat.view(n, c, -1).var(dim=2) + eps
69
+ feat_std = feat_var.sqrt().view(n, c, 1, 1)
70
+ feat_mean = feat.view(n, c, -1).mean(dim=2).view(n, c, 1, 1)
71
+ return feat_mean, feat_std
72
+
73
+
74
+ def adaptive_instance_normalization(content_feat, style_feat):
75
+ """Adaptive instance normalization.
76
+
77
+ Adjust the reference features to have similar color and illumination
78
+ to those of the degraded features.
79
+
80
+ Args:
81
+ content_feat (Tensor): The reference feature.
82
+ style_feat (Tensor): The degraded features.
83
+ """
84
+ size = content_feat.size()
85
+ style_mean, style_std = calc_mean_std(style_feat)
86
+ content_mean, content_std = calc_mean_std(content_feat)
87
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
88
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
89
+
90
+
91
+ def AttentionBlock(in_channel):
92
+ return nn.Sequential(
93
+ spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
94
+ spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)))
95
+
96
+
97
+ def conv_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True):
98
+ """Conv block used in MSDilationBlock."""
99
+
100
+ return nn.Sequential(
101
+ spectral_norm(
102
+ nn.Conv2d(
103
+ in_channels,
104
+ out_channels,
105
+ kernel_size=kernel_size,
106
+ stride=stride,
107
+ dilation=dilation,
108
+ padding=((kernel_size - 1) // 2) * dilation,
109
+ bias=bias)),
110
+ nn.LeakyReLU(0.2),
111
+ spectral_norm(
112
+ nn.Conv2d(
113
+ out_channels,
114
+ out_channels,
115
+ kernel_size=kernel_size,
116
+ stride=stride,
117
+ dilation=dilation,
118
+ padding=((kernel_size - 1) // 2) * dilation,
119
+ bias=bias)),
120
+ )
121
+
122
+
123
+ class MSDilationBlock(nn.Module):
124
+ """Multi-scale dilation block."""
125
+
126
+ def __init__(self, in_channels, kernel_size=3, dilation=(1, 1, 1, 1), bias=True):
127
+ super(MSDilationBlock, self).__init__()
128
+
129
+ self.conv_blocks = nn.ModuleList()
130
+ for i in range(4):
131
+ self.conv_blocks.append(conv_block(in_channels, in_channels, kernel_size, dilation=dilation[i], bias=bias))
132
+ self.conv_fusion = spectral_norm(
133
+ nn.Conv2d(
134
+ in_channels * 4,
135
+ in_channels,
136
+ kernel_size=kernel_size,
137
+ stride=1,
138
+ padding=(kernel_size - 1) // 2,
139
+ bias=bias))
140
+
141
+ def forward(self, x):
142
+ out = []
143
+ for i in range(4):
144
+ out.append(self.conv_blocks[i](x))
145
+ out = torch.cat(out, 1)
146
+ out = self.conv_fusion(out) + x
147
+ return out
148
+
149
+
150
+ class UpResBlock(nn.Module):
151
+
152
+ def __init__(self, in_channel):
153
+ super(UpResBlock, self).__init__()
154
+ self.body = nn.Sequential(
155
+ nn.Conv2d(in_channel, in_channel, 3, 1, 1),
156
+ nn.LeakyReLU(0.2, True),
157
+ nn.Conv2d(in_channel, in_channel, 3, 1, 1),
158
+ )
159
+
160
+ def forward(self, x):
161
+ out = x + self.body(x)
162
+ return out
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/discriminator_arch.py ADDED
@@ -0,0 +1,150 @@
1
+ from torch import nn as nn
2
+ from torch.nn import functional as F
3
+ from torch.nn.utils import spectral_norm
4
+
5
+ from r_basicsr.utils.registry import ARCH_REGISTRY
6
+
7
+
8
+ @ARCH_REGISTRY.register()
9
+ class VGGStyleDiscriminator(nn.Module):
10
+ """VGG style discriminator with input size 128 x 128 or 256 x 256.
11
+
12
+ It is used to train SRGAN, ESRGAN, and VideoGAN.
13
+
14
+ Args:
15
+ num_in_ch (int): Channel number of inputs. Default: 3.
16
+ num_feat (int): Channel number of base intermediate features. Default: 64.
17
+ """
18
+
19
+ def __init__(self, num_in_ch, num_feat, input_size=128):
20
+ super(VGGStyleDiscriminator, self).__init__()
21
+ self.input_size = input_size
22
+ assert self.input_size == 128 or self.input_size == 256, (
23
+ f'input size must be 128 or 256, but received {input_size}')
24
+
25
+ self.conv0_0 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
26
+ self.conv0_1 = nn.Conv2d(num_feat, num_feat, 4, 2, 1, bias=False)
27
+ self.bn0_1 = nn.BatchNorm2d(num_feat, affine=True)
28
+
29
+ self.conv1_0 = nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1, bias=False)
30
+ self.bn1_0 = nn.BatchNorm2d(num_feat * 2, affine=True)
31
+ self.conv1_1 = nn.Conv2d(num_feat * 2, num_feat * 2, 4, 2, 1, bias=False)
32
+ self.bn1_1 = nn.BatchNorm2d(num_feat * 2, affine=True)
33
+
34
+ self.conv2_0 = nn.Conv2d(num_feat * 2, num_feat * 4, 3, 1, 1, bias=False)
35
+ self.bn2_0 = nn.BatchNorm2d(num_feat * 4, affine=True)
36
+ self.conv2_1 = nn.Conv2d(num_feat * 4, num_feat * 4, 4, 2, 1, bias=False)
37
+ self.bn2_1 = nn.BatchNorm2d(num_feat * 4, affine=True)
38
+
39
+ self.conv3_0 = nn.Conv2d(num_feat * 4, num_feat * 8, 3, 1, 1, bias=False)
40
+ self.bn3_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
41
+ self.conv3_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
42
+ self.bn3_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
43
+
44
+ self.conv4_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
45
+ self.bn4_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
46
+ self.conv4_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
47
+ self.bn4_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
48
+
49
+ if self.input_size == 256:
50
+ self.conv5_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False)
51
+ self.bn5_0 = nn.BatchNorm2d(num_feat * 8, affine=True)
52
+ self.conv5_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)
53
+ self.bn5_1 = nn.BatchNorm2d(num_feat * 8, affine=True)
54
+
55
+ self.linear1 = nn.Linear(num_feat * 8 * 4 * 4, 100)
56
+ self.linear2 = nn.Linear(100, 1)
57
+
58
+ # activation function
59
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
60
+
61
+ def forward(self, x):
62
+ assert x.size(2) == self.input_size, (f'Input size must be identical to input_size, but received {x.size()}.')
63
+
64
+ feat = self.lrelu(self.conv0_0(x))
65
+ feat = self.lrelu(self.bn0_1(self.conv0_1(feat))) # output spatial size: /2
66
+
67
+ feat = self.lrelu(self.bn1_0(self.conv1_0(feat)))
68
+ feat = self.lrelu(self.bn1_1(self.conv1_1(feat))) # output spatial size: /4
69
+
70
+ feat = self.lrelu(self.bn2_0(self.conv2_0(feat)))
71
+ feat = self.lrelu(self.bn2_1(self.conv2_1(feat))) # output spatial size: /8
72
+
73
+ feat = self.lrelu(self.bn3_0(self.conv3_0(feat)))
74
+ feat = self.lrelu(self.bn3_1(self.conv3_1(feat))) # output spatial size: /16
75
+
76
+ feat = self.lrelu(self.bn4_0(self.conv4_0(feat)))
77
+ feat = self.lrelu(self.bn4_1(self.conv4_1(feat))) # output spatial size: /32
78
+
79
+ if self.input_size == 256:
80
+ feat = self.lrelu(self.bn5_0(self.conv5_0(feat)))
81
+ feat = self.lrelu(self.bn5_1(self.conv5_1(feat))) # output spatial size: / 64
82
+
83
+ # spatial size: (4, 4)
84
+ feat = feat.view(feat.size(0), -1)
85
+ feat = self.lrelu(self.linear1(feat))
86
+ out = self.linear2(feat)
87
+ return out
88
+
89
+
90
+ @ARCH_REGISTRY.register(suffix='basicsr')
91
+ class UNetDiscriminatorSN(nn.Module):
92
+ """Defines a U-Net discriminator with spectral normalization (SN)
93
+
94
+ It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
95
+
96
+ Args:
97
+ num_in_ch (int): Channel number of inputs. Default: 3.
98
+ num_feat (int): Channel number of base intermediate features. Default: 64.
99
+ skip_connection (bool): Whether to use skip connections in the U-Net. Default: True.
100
+ """
101
+
102
+ def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
103
+ super(UNetDiscriminatorSN, self).__init__()
104
+ self.skip_connection = skip_connection
105
+ norm = spectral_norm
106
+ # the first convolution
107
+ self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
108
+ # downsample
109
+ self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
110
+ self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
111
+ self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
112
+ # upsample
113
+ self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
114
+ self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
115
+ self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
116
+ # extra convolutions
117
+ self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
118
+ self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
119
+ self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
120
+
121
+ def forward(self, x):
122
+ # downsample
123
+ x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
124
+ x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
125
+ x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
126
+ x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
127
+
128
+ # upsample
129
+ x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
130
+ x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
131
+
132
+ if self.skip_connection:
133
+ x4 = x4 + x2
134
+ x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
135
+ x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
136
+
137
+ if self.skip_connection:
138
+ x5 = x5 + x1
139
+ x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
140
+ x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
141
+
142
+ if self.skip_connection:
143
+ x6 = x6 + x0
144
+
145
+ # extra convolutions
146
+ out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
147
+ out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
148
+ out = self.conv9(out)
149
+
150
+ return out
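For reference, both discriminators above are plain PyTorch modules and can be exercised on random inputs. The snippet below is a minimal sketch and not part of the diff; it assumes r_basicsr is importable from the ComfyUI-ReActor directory and that the input sizes satisfy the asserts in the code (128 or 256 for VGGStyleDiscriminator, divisible by 8 for the U-Net variant).

import torch
from r_basicsr.archs.discriminator_arch import UNetDiscriminatorSN, VGGStyleDiscriminator

x = torch.rand(1, 3, 128, 128)
with torch.no_grad():
    score = VGGStyleDiscriminator(num_in_ch=3, num_feat=64, input_size=128)(x)
    score_map = UNetDiscriminatorSN(num_in_ch=3, num_feat=64)(x)
print(score.shape)      # torch.Size([1, 1])           - one realism score per image
print(score_map.shape)  # torch.Size([1, 1, 128, 128]) - per-pixel scores from the U-Net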
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/duf_arch.py ADDED
@@ -0,0 +1,277 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn as nn
4
+ from torch.nn import functional as F
5
+
6
+ from r_basicsr.utils.registry import ARCH_REGISTRY
7
+
8
+
9
+ class DenseBlocksTemporalReduce(nn.Module):
10
+ """A concatenation of 3 dense blocks with reduction in temporal dimension.
11
+
12
+ Note that the output temporal dimension is 6 fewer than the input temporal dimension, since there are 3 blocks.
13
+
14
+ Args:
15
+ num_feat (int): Number of channels in the blocks. Default: 64.
16
+ num_grow_ch (int): Growing factor of the dense blocks. Default: 32
17
+ adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
18
+ Set to false if you want to train from scratch. Default: False.
19
+ """
20
+
21
+ def __init__(self, num_feat=64, num_grow_ch=32, adapt_official_weights=False):
22
+ super(DenseBlocksTemporalReduce, self).__init__()
23
+ if adapt_official_weights:
24
+ eps = 1e-3
25
+ momentum = 1e-3
26
+ else: # pytorch default values
27
+ eps = 1e-05
28
+ momentum = 0.1
29
+
30
+ self.temporal_reduce1 = nn.Sequential(
31
+ nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
32
+ nn.Conv3d(num_feat, num_feat, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True),
33
+ nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
34
+ nn.Conv3d(num_feat, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
35
+
36
+ self.temporal_reduce2 = nn.Sequential(
37
+ nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
38
+ nn.Conv3d(
39
+ num_feat + num_grow_ch,
40
+ num_feat + num_grow_ch, (1, 1, 1),
41
+ stride=(1, 1, 1),
42
+ padding=(0, 0, 0),
43
+ bias=True), nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
44
+ nn.Conv3d(num_feat + num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
45
+
46
+ self.temporal_reduce3 = nn.Sequential(
47
+ nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
48
+ nn.Conv3d(
49
+ num_feat + 2 * num_grow_ch,
50
+ num_feat + 2 * num_grow_ch, (1, 1, 1),
51
+ stride=(1, 1, 1),
52
+ padding=(0, 0, 0),
53
+ bias=True), nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum),
54
+ nn.ReLU(inplace=True),
55
+ nn.Conv3d(
56
+ num_feat + 2 * num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
57
+
58
+ def forward(self, x):
59
+ """
60
+ Args:
61
+ x (Tensor): Input tensor with shape (b, num_feat, t, h, w).
62
+
63
+ Returns:
64
+ Tensor: Output with shape (b, num_feat + num_grow_ch * 3, 1, h, w).
65
+ """
66
+ x1 = self.temporal_reduce1(x)
67
+ x1 = torch.cat((x[:, :, 1:-1, :, :], x1), 1)
68
+
69
+ x2 = self.temporal_reduce2(x1)
70
+ x2 = torch.cat((x1[:, :, 1:-1, :, :], x2), 1)
71
+
72
+ x3 = self.temporal_reduce3(x2)
73
+ x3 = torch.cat((x2[:, :, 1:-1, :, :], x3), 1)
74
+
75
+ return x3
76
+
77
+
78
+ class DenseBlocks(nn.Module):
79
+ """ A concatenation of N dense blocks.
80
+
81
+ Args:
82
+ num_feat (int): Number of channels in the blocks. Default: 64.
83
+ num_grow_ch (int): Growing factor of the dense blocks. Default: 32.
84
+ num_block (int): Number of dense blocks. The values are:
85
+ DUF-S (16 layers): 3
86
+ DUF-M (28 layers): 9
87
+ DUF-L (52 layers): 21
88
+ adapt_official_weights (bool): Whether to adapt the weights translated from the official implementation.
89
+ Set to false if you want to train from scratch. Default: False.
90
+ """
91
+
92
+ def __init__(self, num_block, num_feat=64, num_grow_ch=16, adapt_official_weights=False):
93
+ super(DenseBlocks, self).__init__()
94
+ if adapt_official_weights:
95
+ eps = 1e-3
96
+ momentum = 1e-3
97
+ else: # pytorch default values
98
+ eps = 1e-05
99
+ momentum = 0.1
100
+
101
+ self.dense_blocks = nn.ModuleList()
102
+ for i in range(0, num_block):
103
+ self.dense_blocks.append(
104
+ nn.Sequential(
105
+ nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
106
+ nn.Conv3d(
107
+ num_feat + i * num_grow_ch,
108
+ num_feat + i * num_grow_ch, (1, 1, 1),
109
+ stride=(1, 1, 1),
110
+ padding=(0, 0, 0),
111
+ bias=True), nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum),
112
+ nn.ReLU(inplace=True),
113
+ nn.Conv3d(
114
+ num_feat + i * num_grow_ch,
115
+ num_grow_ch, (3, 3, 3),
116
+ stride=(1, 1, 1),
117
+ padding=(1, 1, 1),
118
+ bias=True)))
119
+
120
+ def forward(self, x):
121
+ """
122
+ Args:
123
+ x (Tensor): Input tensor with shape (b, num_feat, t, h, w).
124
+
125
+ Returns:
126
+ Tensor: Output with shape (b, num_feat + num_block * num_grow_ch, t, h, w).
127
+ """
128
+ for i in range(0, len(self.dense_blocks)):
129
+ y = self.dense_blocks[i](x)
130
+ x = torch.cat((x, y), 1)
131
+ return x
132
+
133
+
134
+ class DynamicUpsamplingFilter(nn.Module):
135
+ """Dynamic upsampling filter used in DUF.
136
+
137
+ Ref: https://github.com/yhjo09/VSR-DUF.
138
+ It only supports input with 3 channels, and it applies the same filters to all 3 channels.
139
+
140
+ Args:
141
+ filter_size (tuple): Filter size of generated filters. The shape is (kh, kw). Default: (5, 5).
142
+ """
143
+
144
+ def __init__(self, filter_size=(5, 5)):
145
+ super(DynamicUpsamplingFilter, self).__init__()
146
+ if not isinstance(filter_size, tuple):
147
+ raise TypeError(f'The type of filter_size must be tuple, but got {type(filter_size)}')
148
+ if len(filter_size) != 2:
149
+ raise ValueError(f'The length of filter size must be 2, but got {len(filter_size)}.')
150
+ # generate a local expansion filter, similar to im2col
151
+ self.filter_size = filter_size
152
+ filter_prod = np.prod(filter_size)
153
+ expansion_filter = torch.eye(int(filter_prod)).view(filter_prod, 1, *filter_size) # (kh*kw, 1, kh, kw)
154
+ self.expansion_filter = expansion_filter.repeat(3, 1, 1, 1) # repeat for all the 3 channels
155
+
156
+ def forward(self, x, filters):
157
+ """Forward function for DynamicUpsamplingFilter.
158
+
159
+ Args:
160
+ x (Tensor): Input image with 3 channels. The shape is (n, 3, h, w).
161
+ filters (Tensor): Generated dynamic filters.
162
+ The shape is (n, filter_prod, upsampling_square, h, w).
163
+ filter_prod: prod of filter kernel size, e.g., 1*5*5=25.
164
+ upsampling_square: similar to pixel shuffle,
165
+ upsampling_square = upsampling * upsampling
166
+ e.g., for x 4 upsampling, upsampling_square= 4*4 = 16
167
+
168
+ Returns:
169
+ Tensor: Filtered image with shape (n, 3*upsampling_square, h, w)
170
+ """
171
+ n, filter_prod, upsampling_square, h, w = filters.size()
172
+ kh, kw = self.filter_size
173
+ expanded_input = F.conv2d(
174
+ x, self.expansion_filter.to(x), padding=(kh // 2, kw // 2), groups=3) # (n, 3*filter_prod, h, w)
175
+ expanded_input = expanded_input.view(n, 3, filter_prod, h, w).permute(0, 3, 4, 1,
176
+ 2) # (n, h, w, 3, filter_prod)
177
+ filters = filters.permute(0, 3, 4, 1, 2) # (n, h, w, filter_prod, upsampling_square)
178
+ out = torch.matmul(expanded_input, filters) # (n, h, w, 3, upsampling_square)
179
+ return out.permute(0, 3, 4, 1, 2).view(n, 3 * upsampling_square, h, w)
180
+
181
+
182
+ @ARCH_REGISTRY.register()
183
+ class DUF(nn.Module):
184
+ """Network architecture for DUF
185
+
186
+ Paper: Jo et al. Deep Video Super-Resolution Network Using Dynamic
187
+ Upsampling Filters Without Explicit Motion Compensation, CVPR, 2018
188
+ Code reference:
189
+ https://github.com/yhjo09/VSR-DUF
190
+ For all the models below, 'adapt_official_weights' is only necessary when
191
+ loading the weights converted from the official TensorFlow weights.
192
+ Please set it to False if you are training the model from scratch.
193
+
194
+ There are three models with different model sizes: DUF16Layers, DUF28Layers,
195
+ and DUF52Layers. This class is the base class for these models.
196
+
197
+ Args:
198
+ scale (int): The upsampling factor. Default: 4.
199
+ num_layer (int): The number of layers. Default: 52.
200
+ adapt_official_weights (bool): Whether to adapt the weights
201
+ translated from the official implementation. Set to false if you
202
+ want to train from scratch. Default: False.
203
+ """
204
+
205
+ def __init__(self, scale=4, num_layer=52, adapt_official_weights=False):
206
+ super(DUF, self).__init__()
207
+ self.scale = scale
208
+ if adapt_official_weights:
209
+ eps = 1e-3
210
+ momentum = 1e-3
211
+ else: # pytorch default values
212
+ eps = 1e-05
213
+ momentum = 0.1
214
+
215
+ self.conv3d1 = nn.Conv3d(3, 64, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
216
+ self.dynamic_filter = DynamicUpsamplingFilter((5, 5))
217
+
218
+ if num_layer == 16:
219
+ num_block = 3
220
+ num_grow_ch = 32
221
+ elif num_layer == 28:
222
+ num_block = 9
223
+ num_grow_ch = 16
224
+ elif num_layer == 52:
225
+ num_block = 21
226
+ num_grow_ch = 16
227
+ else:
228
+ raise ValueError(f'Only supported (16, 28, 52) layers, but got {num_layer}.')
229
+
230
+ self.dense_block1 = DenseBlocks(
231
+ num_block=num_block, num_feat=64, num_grow_ch=num_grow_ch,
232
+ adapt_official_weights=adapt_official_weights) # T = 7
233
+ self.dense_block2 = DenseBlocksTemporalReduce(
234
+ 64 + num_grow_ch * num_block, num_grow_ch, adapt_official_weights=adapt_official_weights) # T = 1
235
+ channels = 64 + num_grow_ch * num_block + num_grow_ch * 3
236
+ self.bn3d2 = nn.BatchNorm3d(channels, eps=eps, momentum=momentum)
237
+ self.conv3d2 = nn.Conv3d(channels, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
238
+
239
+ self.conv3d_r1 = nn.Conv3d(256, 256, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
240
+ self.conv3d_r2 = nn.Conv3d(256, 3 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
241
+
242
+ self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
243
+ self.conv3d_f2 = nn.Conv3d(
244
+ 512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
245
+
246
+ def forward(self, x):
247
+ """
248
+ Args:
249
+ x (Tensor): Input with shape (b, 7, c, h, w)
250
+
251
+ Returns:
252
+ Tensor: Output with shape (b, c, h * scale, w * scale)
253
+ """
254
+ num_batches, num_imgs, _, h, w = x.size()
255
+
256
+ x = x.permute(0, 2, 1, 3, 4) # (b, c, 7, h, w) for Conv3D
257
+ x_center = x[:, :, num_imgs // 2, :, :]
258
+
259
+ x = self.conv3d1(x)
260
+ x = self.dense_block1(x)
261
+ x = self.dense_block2(x)
262
+ x = F.relu(self.bn3d2(x), inplace=True)
263
+ x = F.relu(self.conv3d2(x), inplace=True)
264
+
265
+ # residual image
266
+ res = self.conv3d_r2(F.relu(self.conv3d_r1(x), inplace=True))
267
+
268
+ # filter
269
+ filter_ = self.conv3d_f2(F.relu(self.conv3d_f1(x), inplace=True))
270
+ filter_ = F.softmax(filter_.view(num_batches, 25, self.scale**2, h, w), dim=1)
271
+
272
+ # dynamic filter
273
+ out = self.dynamic_filter(x_center, filter_)
274
+ out += res.squeeze_(2)
275
+ out = F.pixel_shuffle(out, self.scale)
276
+
277
+ return out
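DUF consumes a fixed 7-frame window and super-resolves the center frame with per-pixel dynamic filters plus a residual. Below is a minimal sketch, not part of the diff, assuming r_basicsr is importable; it uses the smallest (16-layer) variant with untrained weights, so only the output shape is meaningful.

import torch
from r_basicsr.archs.duf_arch import DUF

net = DUF(scale=4, num_layer=16, adapt_official_weights=False).eval()
clip = torch.rand(1, 7, 3, 32, 32)  # (b, t, c, h, w); the temporal window must be 7 frames
with torch.no_grad():
    sr = net(clip)
print(sr.shape)  # torch.Size([1, 3, 128, 128]) - center frame upsampled by 4x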
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/ecbsr_arch.py ADDED
@@ -0,0 +1,274 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from r_basicsr.utils.registry import ARCH_REGISTRY
6
+
7
+
8
+ class SeqConv3x3(nn.Module):
9
+ """The re-parameterizable block used in the ECBSR architecture.
10
+
11
+ Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
12
+ Ref git repo: https://github.com/xindongzhang/ECBSR
13
+
14
+ Args:
15
+ seq_type (str): Sequence type, option: conv1x1-conv3x3 | conv1x1-sobelx | conv1x1-sobely | conv1x1-laplacian.
16
+ in_channels (int): Channel number of input.
17
+ out_channels (int): Channel number of output.
18
+ depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1.
19
+ """
20
+
21
+ def __init__(self, seq_type, in_channels, out_channels, depth_multiplier=1):
22
+ super(SeqConv3x3, self).__init__()
23
+ self.seq_type = seq_type
24
+ self.in_channels = in_channels
25
+ self.out_channels = out_channels
26
+
27
+ if self.seq_type == 'conv1x1-conv3x3':
28
+ self.mid_planes = int(out_channels * depth_multiplier)
29
+ conv0 = torch.nn.Conv2d(self.in_channels, self.mid_planes, kernel_size=1, padding=0)
30
+ self.k0 = conv0.weight
31
+ self.b0 = conv0.bias
32
+
33
+ conv1 = torch.nn.Conv2d(self.mid_planes, self.out_channels, kernel_size=3)
34
+ self.k1 = conv1.weight
35
+ self.b1 = conv1.bias
36
+
37
+ elif self.seq_type == 'conv1x1-sobelx':
38
+ conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
39
+ self.k0 = conv0.weight
40
+ self.b0 = conv0.bias
41
+
42
+ # init scale and bias
43
+ scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
44
+ self.scale = nn.Parameter(scale)
45
+ bias = torch.randn(self.out_channels) * 1e-3
46
+ bias = torch.reshape(bias, (self.out_channels, ))
47
+ self.bias = nn.Parameter(bias)
48
+ # init mask
49
+ self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
50
+ for i in range(self.out_channels):
51
+ self.mask[i, 0, 0, 0] = 1.0
52
+ self.mask[i, 0, 1, 0] = 2.0
53
+ self.mask[i, 0, 2, 0] = 1.0
54
+ self.mask[i, 0, 0, 2] = -1.0
55
+ self.mask[i, 0, 1, 2] = -2.0
56
+ self.mask[i, 0, 2, 2] = -1.0
57
+ self.mask = nn.Parameter(data=self.mask, requires_grad=False)
58
+
59
+ elif self.seq_type == 'conv1x1-sobely':
60
+ conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
61
+ self.k0 = conv0.weight
62
+ self.b0 = conv0.bias
63
+
64
+ # init scale and bias
65
+ scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
66
+ self.scale = nn.Parameter(torch.FloatTensor(scale))
67
+ bias = torch.randn(self.out_channels) * 1e-3
68
+ bias = torch.reshape(bias, (self.out_channels, ))
69
+ self.bias = nn.Parameter(torch.FloatTensor(bias))
70
+ # init mask
71
+ self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
72
+ for i in range(self.out_channels):
73
+ self.mask[i, 0, 0, 0] = 1.0
74
+ self.mask[i, 0, 0, 1] = 2.0
75
+ self.mask[i, 0, 0, 2] = 1.0
76
+ self.mask[i, 0, 2, 0] = -1.0
77
+ self.mask[i, 0, 2, 1] = -2.0
78
+ self.mask[i, 0, 2, 2] = -1.0
79
+ self.mask = nn.Parameter(data=self.mask, requires_grad=False)
80
+
81
+ elif self.seq_type == 'conv1x1-laplacian':
82
+ conv0 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, padding=0)
83
+ self.k0 = conv0.weight
84
+ self.b0 = conv0.bias
85
+
86
+ # init scale and bias
87
+ scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
88
+ self.scale = nn.Parameter(torch.FloatTensor(scale))
89
+ bias = torch.randn(self.out_channels) * 1e-3
90
+ bias = torch.reshape(bias, (self.out_channels, ))
91
+ self.bias = nn.Parameter(torch.FloatTensor(bias))
92
+ # init mask
93
+ self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
94
+ for i in range(self.out_channels):
95
+ self.mask[i, 0, 0, 1] = 1.0
96
+ self.mask[i, 0, 1, 0] = 1.0
97
+ self.mask[i, 0, 1, 2] = 1.0
98
+ self.mask[i, 0, 2, 1] = 1.0
99
+ self.mask[i, 0, 1, 1] = -4.0
100
+ self.mask = nn.Parameter(data=self.mask, requires_grad=False)
101
+ else:
102
+ raise ValueError('The type of seqconv is not supported!')
103
+
104
+ def forward(self, x):
105
+ if self.seq_type == 'conv1x1-conv3x3':
106
+ # conv-1x1
107
+ y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
108
+ # explicitly padding with bias
109
+ y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
110
+ b0_pad = self.b0.view(1, -1, 1, 1)
111
+ y0[:, :, 0:1, :] = b0_pad
112
+ y0[:, :, -1:, :] = b0_pad
113
+ y0[:, :, :, 0:1] = b0_pad
114
+ y0[:, :, :, -1:] = b0_pad
115
+ # conv-3x3
116
+ y1 = F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
117
+ else:
118
+ y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
119
+ # explicitly padding with bias
120
+ y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
121
+ b0_pad = self.b0.view(1, -1, 1, 1)
122
+ y0[:, :, 0:1, :] = b0_pad
123
+ y0[:, :, -1:, :] = b0_pad
124
+ y0[:, :, :, 0:1] = b0_pad
125
+ y0[:, :, :, -1:] = b0_pad
126
+ # conv-3x3
127
+ y1 = F.conv2d(input=y0, weight=self.scale * self.mask, bias=self.bias, stride=1, groups=self.out_channels)
128
+ return y1
129
+
130
+ def rep_params(self):
131
+ device = self.k0.get_device()
132
+ if device < 0:
133
+ device = None
134
+
135
+ if self.seq_type == 'conv1x1-conv3x3':
136
+ # re-param conv kernel
137
+ rep_weight = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3))
138
+ # re-param conv bias
139
+ rep_bias = torch.ones(1, self.mid_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
140
+ rep_bias = F.conv2d(input=rep_bias, weight=self.k1).view(-1, ) + self.b1
141
+ else:
142
+ tmp = self.scale * self.mask
143
+ k1 = torch.zeros((self.out_channels, self.out_channels, 3, 3), device=device)
144
+ for i in range(self.out_channels):
145
+ k1[i, i, :, :] = tmp[i, 0, :, :]
146
+ b1 = self.bias
147
+ # re-param conv kernel
148
+ rep_weight = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3))
149
+ # re-param conv bias
150
+ rep_bias = torch.ones(1, self.out_channels, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
151
+ rep_bias = F.conv2d(input=rep_bias, weight=k1).view(-1, ) + b1
152
+ return rep_weight, rep_bias
153
+
154
+
155
+ class ECB(nn.Module):
156
+ """The ECB block used in the ECBSR architecture.
157
+
158
+ Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
159
+ Ref git repo: https://github.com/xindongzhang/ECBSR
160
+
161
+ Args:
162
+ in_channels (int): Channel number of input.
163
+ out_channels (int): Channel number of output.
164
+ depth_multiplier (int): Width multiplier in the expand-and-squeeze conv. Default: 1.
165
+ act_type (str): Activation type. Option: prelu | relu | rrelu | softplus | linear. Default: prelu.
166
+ with_idt (bool): Whether to use identity connection. Default: False.
167
+ """
168
+
169
+ def __init__(self, in_channels, out_channels, depth_multiplier, act_type='prelu', with_idt=False):
170
+ super(ECB, self).__init__()
171
+
172
+ self.depth_multiplier = depth_multiplier
173
+ self.in_channels = in_channels
174
+ self.out_channels = out_channels
175
+ self.act_type = act_type
176
+
177
+ if with_idt and (self.in_channels == self.out_channels):
178
+ self.with_idt = True
179
+ else:
180
+ self.with_idt = False
181
+
182
+ self.conv3x3 = torch.nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1)
183
+ self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.in_channels, self.out_channels, self.depth_multiplier)
184
+ self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.in_channels, self.out_channels)
185
+ self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.in_channels, self.out_channels)
186
+ self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.in_channels, self.out_channels)
187
+
188
+ if self.act_type == 'prelu':
189
+ self.act = nn.PReLU(num_parameters=self.out_channels)
190
+ elif self.act_type == 'relu':
191
+ self.act = nn.ReLU(inplace=True)
192
+ elif self.act_type == 'rrelu':
193
+ self.act = nn.RReLU(lower=-0.05, upper=0.05)
194
+ elif self.act_type == 'softplus':
195
+ self.act = nn.Softplus()
196
+ elif self.act_type == 'linear':
197
+ pass
198
+ else:
199
+ raise ValueError('The type of activation is not supported!')
200
+
201
+ def forward(self, x):
202
+ if self.training:
203
+ y = self.conv3x3(x) + self.conv1x1_3x3(x) + self.conv1x1_sbx(x) + self.conv1x1_sby(x) + self.conv1x1_lpl(x)
204
+ if self.with_idt:
205
+ y += x
206
+ else:
207
+ rep_weight, rep_bias = self.rep_params()
208
+ y = F.conv2d(input=x, weight=rep_weight, bias=rep_bias, stride=1, padding=1)
209
+ if self.act_type != 'linear':
210
+ y = self.act(y)
211
+ return y
212
+
213
+ def rep_params(self):
214
+ weight0, bias0 = self.conv3x3.weight, self.conv3x3.bias
215
+ weight1, bias1 = self.conv1x1_3x3.rep_params()
216
+ weight2, bias2 = self.conv1x1_sbx.rep_params()
217
+ weight3, bias3 = self.conv1x1_sby.rep_params()
218
+ weight4, bias4 = self.conv1x1_lpl.rep_params()
219
+ rep_weight, rep_bias = (weight0 + weight1 + weight2 + weight3 + weight4), (
220
+ bias0 + bias1 + bias2 + bias3 + bias4)
221
+
222
+ if self.with_idt:
223
+ device = rep_weight.get_device()
224
+ if device < 0:
225
+ device = None
226
+ weight_idt = torch.zeros(self.out_channels, self.out_channels, 3, 3, device=device)
227
+ for i in range(self.out_channels):
228
+ weight_idt[i, i, 1, 1] = 1.0
229
+ bias_idt = 0.0
230
+ rep_weight, rep_bias = rep_weight + weight_idt, rep_bias + bias_idt
231
+ return rep_weight, rep_bias
232
+
233
+
234
+ @ARCH_REGISTRY.register()
235
+ class ECBSR(nn.Module):
236
+ """ECBSR architecture.
237
+
238
+ Paper: Edge-oriented Convolution Block for Real-time Super Resolution on Mobile Devices
239
+ Ref git repo: https://github.com/xindongzhang/ECBSR
240
+
241
+ Args:
242
+ num_in_ch (int): Channel number of inputs.
243
+ num_out_ch (int): Channel number of outputs.
244
+ num_block (int): Block number in the trunk network.
245
+ num_channel (int): Channel number.
246
+ with_idt (bool): Whether use identity in convolution layers.
247
+ act_type (str): Activation type.
248
+ scale (int): Upsampling factor.
249
+ """
250
+
251
+ def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with_idt, act_type, scale):
252
+ super(ECBSR, self).__init__()
253
+ self.num_in_ch = num_in_ch
254
+ self.scale = scale
255
+
256
+ backbone = []
257
+ backbone += [ECB(num_in_ch, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
258
+ for _ in range(num_block):
259
+ backbone += [ECB(num_channel, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
260
+ backbone += [
261
+ ECB(num_channel, num_out_ch * scale * scale, depth_multiplier=2.0, act_type='linear', with_idt=with_idt)
262
+ ]
263
+
264
+ self.backbone = nn.Sequential(*backbone)
265
+ self.upsampler = nn.PixelShuffle(scale)
266
+
267
+ def forward(self, x):
268
+ if self.num_in_ch > 1:
269
+ shortcut = torch.repeat_interleave(x, self.scale * self.scale, dim=1)
270
+ else:
271
+ shortcut = x # will repeat the input in the channel dimension (repeat scale * scale times)
272
+ y = self.backbone(x) + shortcut
273
+ y = self.upsampler(y)
274
+ return y
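The point of SeqConv3x3/ECB is that the multi-branch training graph collapses into a single 3x3 convolution at inference time (ECB.forward switches on self.training and calls rep_params()). Below is a minimal sketch, not part of the diff, assuming r_basicsr is importable; it checks that the two paths agree up to floating-point error.

import torch
from r_basicsr.archs.ecbsr_arch import ECB

block = ECB(in_channels=8, out_channels=8, depth_multiplier=2.0, act_type='linear', with_idt=True)
x = torch.randn(1, 8, 16, 16)

block.train()
with torch.no_grad():
    y_branches = block(x)  # sum of conv3x3 + the four SeqConv3x3 branches (+ identity)

block.eval()
with torch.no_grad():
    y_reparam = block(x)   # single fused 3x3 conv built from rep_params()

print(torch.allclose(y_branches, y_reparam, atol=1e-4))  # expected: True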
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/edsr_arch.py ADDED
@@ -0,0 +1,61 @@
1
+ import torch
2
+ from torch import nn as nn
3
+
4
+ from r_basicsr.archs.arch_util import ResidualBlockNoBN, Upsample, make_layer
5
+ from r_basicsr.utils.registry import ARCH_REGISTRY
6
+
7
+
8
+ @ARCH_REGISTRY.register()
9
+ class EDSR(nn.Module):
10
+ """EDSR network structure.
11
+
12
+ Paper: Enhanced Deep Residual Networks for Single Image Super-Resolution.
13
+ Ref git repo: https://github.com/thstkdgus35/EDSR-PyTorch
14
+
15
+ Args:
16
+ num_in_ch (int): Channel number of inputs.
17
+ num_out_ch (int): Channel number of outputs.
18
+ num_feat (int): Channel number of intermediate features.
19
+ Default: 64.
20
+ num_block (int): Block number in the trunk network. Default: 16.
21
+ upscale (int): Upsampling factor. Support 2^n and 3.
22
+ Default: 4.
23
+ res_scale (float): Used to scale the residual in residual block.
24
+ Default: 1.
25
+ img_range (float): Image range. Default: 255.
26
+ rgb_mean (tuple[float]): Image mean in RGB orders.
27
+ Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
28
+ """
29
+
30
+ def __init__(self,
31
+ num_in_ch,
32
+ num_out_ch,
33
+ num_feat=64,
34
+ num_block=16,
35
+ upscale=4,
36
+ res_scale=1,
37
+ img_range=255.,
38
+ rgb_mean=(0.4488, 0.4371, 0.4040)):
39
+ super(EDSR, self).__init__()
40
+
41
+ self.img_range = img_range
42
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
43
+
44
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
45
+ self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat, res_scale=res_scale, pytorch_init=True)
46
+ self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
47
+ self.upsample = Upsample(upscale, num_feat)
48
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
49
+
50
+ def forward(self, x):
51
+ self.mean = self.mean.type_as(x)
52
+
53
+ x = (x - self.mean) * self.img_range
54
+ x = self.conv_first(x)
55
+ res = self.conv_after_body(self.body(x))
56
+ res += x
57
+
58
+ x = self.conv_last(self.upsample(res))
59
+ x = x / self.img_range + self.mean
60
+
61
+ return x
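EDSR is a straightforward residual trunk plus pixel-shuffle upsampler. Below is a minimal sketch, not part of the diff, assuming r_basicsr (including r_basicsr.archs.arch_util) is importable; the weights are random, so the output is only useful for checking shapes.

import torch
from r_basicsr.archs.edsr_arch import EDSR

net = EDSR(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4).eval()
lr = torch.rand(1, 3, 48, 48)  # values in [0, 1]; the net subtracts the DIV2K mean and scales by img_range
with torch.no_grad():
    sr = net(lr)
print(sr.shape)  # torch.Size([1, 3, 192, 192])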
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/edvr_arch.py ADDED
@@ -0,0 +1,383 @@
1
+ import torch
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ from r_basicsr.utils.registry import ARCH_REGISTRY
6
+ from .arch_util import DCNv2Pack, ResidualBlockNoBN, make_layer
7
+
8
+
9
+ class PCDAlignment(nn.Module):
10
+ """Alignment module using Pyramid, Cascading and Deformable convolution
11
+ (PCD). It is used in EDVR.
12
+
13
+ Ref:
14
+ EDVR: Video Restoration with Enhanced Deformable Convolutional Networks
15
+
16
+ Args:
17
+ num_feat (int): Channel number of middle features. Default: 64.
18
+ deformable_groups (int): Deformable groups. Defaults: 8.
19
+ """
20
+
21
+ def __init__(self, num_feat=64, deformable_groups=8):
22
+ super(PCDAlignment, self).__init__()
23
+
24
+ # Pyramid has three levels:
25
+ # L3: level 3, 1/4 spatial size
26
+ # L2: level 2, 1/2 spatial size
27
+ # L1: level 1, original spatial size
28
+ self.offset_conv1 = nn.ModuleDict()
29
+ self.offset_conv2 = nn.ModuleDict()
30
+ self.offset_conv3 = nn.ModuleDict()
31
+ self.dcn_pack = nn.ModuleDict()
32
+ self.feat_conv = nn.ModuleDict()
33
+
34
+ # Pyramids
35
+ for i in range(3, 0, -1):
36
+ level = f'l{i}'
37
+ self.offset_conv1[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
38
+ if i == 3:
39
+ self.offset_conv2[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
40
+ else:
41
+ self.offset_conv2[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
42
+ self.offset_conv3[level] = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
43
+ self.dcn_pack[level] = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)
44
+
45
+ if i < 3:
46
+ self.feat_conv[level] = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
47
+
48
+ # Cascading dcn
49
+ self.cas_offset_conv1 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
50
+ self.cas_offset_conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
51
+ self.cas_dcnpack = DCNv2Pack(num_feat, num_feat, 3, padding=1, deformable_groups=deformable_groups)
52
+
53
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
54
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
55
+
56
+ def forward(self, nbr_feat_l, ref_feat_l):
57
+ """Align neighboring frame features to the reference frame features.
58
+
59
+ Args:
60
+ nbr_feat_l (list[Tensor]): Neighboring feature list. It
61
+ contains three pyramid levels (L1, L2, L3),
62
+ each with shape (b, c, h, w).
63
+ ref_feat_l (list[Tensor]): Reference feature list. It
64
+ contains three pyramid levels (L1, L2, L3),
65
+ each with shape (b, c, h, w).
66
+
67
+ Returns:
68
+ Tensor: Aligned features.
69
+ """
70
+ # Pyramids
71
+ upsampled_offset, upsampled_feat = None, None
72
+ for i in range(3, 0, -1):
73
+ level = f'l{i}'
74
+ offset = torch.cat([nbr_feat_l[i - 1], ref_feat_l[i - 1]], dim=1)
75
+ offset = self.lrelu(self.offset_conv1[level](offset))
76
+ if i == 3:
77
+ offset = self.lrelu(self.offset_conv2[level](offset))
78
+ else:
79
+ offset = self.lrelu(self.offset_conv2[level](torch.cat([offset, upsampled_offset], dim=1)))
80
+ offset = self.lrelu(self.offset_conv3[level](offset))
81
+
82
+ feat = self.dcn_pack[level](nbr_feat_l[i - 1], offset)
83
+ if i < 3:
84
+ feat = self.feat_conv[level](torch.cat([feat, upsampled_feat], dim=1))
85
+ if i > 1:
86
+ feat = self.lrelu(feat)
87
+
88
+ if i > 1: # upsample offset and features
89
+ # x2: when we upsample the offset, we should also enlarge
90
+ # the magnitude.
91
+ upsampled_offset = self.upsample(offset) * 2
92
+ upsampled_feat = self.upsample(feat)
93
+
94
+ # Cascading
95
+ offset = torch.cat([feat, ref_feat_l[0]], dim=1)
96
+ offset = self.lrelu(self.cas_offset_conv2(self.lrelu(self.cas_offset_conv1(offset))))
97
+ feat = self.lrelu(self.cas_dcnpack(feat, offset))
98
+ return feat
99
+
100
+
101
+ class TSAFusion(nn.Module):
102
+ """Temporal Spatial Attention (TSA) fusion module.
103
+
104
+ Temporal: Calculate the correlation between center frame and
105
+ neighboring frames;
106
+ Spatial: It has 3 pyramid levels, the attention is similar to SFT.
107
+ (SFT: Recovering realistic texture in image super-resolution by deep
108
+ spatial feature transform.)
109
+
110
+ Args:
111
+ num_feat (int): Channel number of middle features. Default: 64.
112
+ num_frame (int): Number of frames. Default: 5.
113
+ center_frame_idx (int): The index of center frame. Default: 2.
114
+ """
115
+
116
+ def __init__(self, num_feat=64, num_frame=5, center_frame_idx=2):
117
+ super(TSAFusion, self).__init__()
118
+ self.center_frame_idx = center_frame_idx
119
+ # temporal attention (before fusion conv)
120
+ self.temporal_attn1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
121
+ self.temporal_attn2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
122
+ self.feat_fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)
123
+
124
+ # spatial attention (after fusion conv)
125
+ self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)
126
+ self.avg_pool = nn.AvgPool2d(3, stride=2, padding=1)
127
+ self.spatial_attn1 = nn.Conv2d(num_frame * num_feat, num_feat, 1)
128
+ self.spatial_attn2 = nn.Conv2d(num_feat * 2, num_feat, 1)
129
+ self.spatial_attn3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
130
+ self.spatial_attn4 = nn.Conv2d(num_feat, num_feat, 1)
131
+ self.spatial_attn5 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
132
+ self.spatial_attn_l1 = nn.Conv2d(num_feat, num_feat, 1)
133
+ self.spatial_attn_l2 = nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1)
134
+ self.spatial_attn_l3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
135
+ self.spatial_attn_add1 = nn.Conv2d(num_feat, num_feat, 1)
136
+ self.spatial_attn_add2 = nn.Conv2d(num_feat, num_feat, 1)
137
+
138
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
139
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
140
+
141
+ def forward(self, aligned_feat):
142
+ """
143
+ Args:
144
+ aligned_feat (Tensor): Aligned features with shape (b, t, c, h, w).
145
+
146
+ Returns:
147
+ Tensor: Features after TSA with the shape (b, c, h, w).
148
+ """
149
+ b, t, c, h, w = aligned_feat.size()
150
+ # temporal attention
151
+ embedding_ref = self.temporal_attn1(aligned_feat[:, self.center_frame_idx, :, :, :].clone())
152
+ embedding = self.temporal_attn2(aligned_feat.view(-1, c, h, w))
153
+ embedding = embedding.view(b, t, -1, h, w) # (b, t, c, h, w)
154
+
155
+ corr_l = [] # correlation list
156
+ for i in range(t):
157
+ emb_neighbor = embedding[:, i, :, :, :]
158
+ corr = torch.sum(emb_neighbor * embedding_ref, 1) # (b, h, w)
159
+ corr_l.append(corr.unsqueeze(1)) # (b, 1, h, w)
160
+ corr_prob = torch.sigmoid(torch.cat(corr_l, dim=1)) # (b, t, h, w)
161
+ corr_prob = corr_prob.unsqueeze(2).expand(b, t, c, h, w)
162
+ corr_prob = corr_prob.contiguous().view(b, -1, h, w) # (b, t*c, h, w)
163
+ aligned_feat = aligned_feat.view(b, -1, h, w) * corr_prob
164
+
165
+ # fusion
166
+ feat = self.lrelu(self.feat_fusion(aligned_feat))
167
+
168
+ # spatial attention
169
+ attn = self.lrelu(self.spatial_attn1(aligned_feat))
170
+ attn_max = self.max_pool(attn)
171
+ attn_avg = self.avg_pool(attn)
172
+ attn = self.lrelu(self.spatial_attn2(torch.cat([attn_max, attn_avg], dim=1)))
173
+ # pyramid levels
174
+ attn_level = self.lrelu(self.spatial_attn_l1(attn))
175
+ attn_max = self.max_pool(attn_level)
176
+ attn_avg = self.avg_pool(attn_level)
177
+ attn_level = self.lrelu(self.spatial_attn_l2(torch.cat([attn_max, attn_avg], dim=1)))
178
+ attn_level = self.lrelu(self.spatial_attn_l3(attn_level))
179
+ attn_level = self.upsample(attn_level)
180
+
181
+ attn = self.lrelu(self.spatial_attn3(attn)) + attn_level
182
+ attn = self.lrelu(self.spatial_attn4(attn))
183
+ attn = self.upsample(attn)
184
+ attn = self.spatial_attn5(attn)
185
+ attn_add = self.spatial_attn_add2(self.lrelu(self.spatial_attn_add1(attn)))
186
+ attn = torch.sigmoid(attn)
187
+
188
+ # after initialization, * 2 makes (attn * 2) to be close to 1.
189
+ feat = feat * attn * 2 + attn_add
190
+ return feat
191
+
192
+
193
+ class PredeblurModule(nn.Module):
194
+ """Pre-dublur module.
195
+
196
+ Args:
197
+ num_in_ch (int): Channel number of input image. Default: 3.
198
+ num_feat (int): Channel number of intermediate features. Default: 64.
199
+ hr_in (bool): Whether the input has high resolution. Default: False.
200
+ """
201
+
202
+ def __init__(self, num_in_ch=3, num_feat=64, hr_in=False):
203
+ super(PredeblurModule, self).__init__()
204
+ self.hr_in = hr_in
205
+
206
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
207
+ if self.hr_in:
208
+ # downsample x4 by stride conv
209
+ self.stride_conv_hr1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
210
+ self.stride_conv_hr2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
211
+
212
+ # generate feature pyramid
213
+ self.stride_conv_l2 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
214
+ self.stride_conv_l3 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
215
+
216
+ self.resblock_l3 = ResidualBlockNoBN(num_feat=num_feat)
217
+ self.resblock_l2_1 = ResidualBlockNoBN(num_feat=num_feat)
218
+ self.resblock_l2_2 = ResidualBlockNoBN(num_feat=num_feat)
219
+ self.resblock_l1 = nn.ModuleList([ResidualBlockNoBN(num_feat=num_feat) for i in range(5)])
220
+
221
+ self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
222
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
223
+
224
+ def forward(self, x):
225
+ feat_l1 = self.lrelu(self.conv_first(x))
226
+ if self.hr_in:
227
+ feat_l1 = self.lrelu(self.stride_conv_hr1(feat_l1))
228
+ feat_l1 = self.lrelu(self.stride_conv_hr2(feat_l1))
229
+
230
+ # generate feature pyramid
231
+ feat_l2 = self.lrelu(self.stride_conv_l2(feat_l1))
232
+ feat_l3 = self.lrelu(self.stride_conv_l3(feat_l2))
233
+
234
+ feat_l3 = self.upsample(self.resblock_l3(feat_l3))
235
+ feat_l2 = self.resblock_l2_1(feat_l2) + feat_l3
236
+ feat_l2 = self.upsample(self.resblock_l2_2(feat_l2))
237
+
238
+ for i in range(2):
239
+ feat_l1 = self.resblock_l1[i](feat_l1)
240
+ feat_l1 = feat_l1 + feat_l2
241
+ for i in range(2, 5):
242
+ feat_l1 = self.resblock_l1[i](feat_l1)
243
+ return feat_l1
244
+
245
+
246
+ @ARCH_REGISTRY.register()
247
+ class EDVR(nn.Module):
248
+ """EDVR network structure for video super-resolution.
249
+
250
+ Now only supports the x4 upsampling factor.
251
+ Paper:
252
+ EDVR: Video Restoration with Enhanced Deformable Convolutional Networks
253
+
254
+ Args:
255
+ num_in_ch (int): Channel number of input image. Default: 3.
256
+ num_out_ch (int): Channel number of output image. Default: 3.
257
+ num_feat (int): Channel number of intermediate features. Default: 64.
258
+ num_frame (int): Number of input frames. Default: 5.
259
+ deformable_groups (int): Deformable groups. Defaults: 8.
260
+ num_extract_block (int): Number of blocks for feature extraction.
261
+ Default: 5.
262
+ num_reconstruct_block (int): Number of blocks for reconstruction.
263
+ Default: 10.
264
+ center_frame_idx (int): The index of center frame. Frame counting from
265
+ 0. Default: Middle of input frames.
266
+ hr_in (bool): Whether the input has high resolution. Default: False.
267
+ with_predeblur (bool): Whether has predeblur module.
268
+ Default: False.
269
+ with_tsa (bool): Whether has TSA module. Default: True.
270
+ """
271
+
272
+ def __init__(self,
273
+ num_in_ch=3,
274
+ num_out_ch=3,
275
+ num_feat=64,
276
+ num_frame=5,
277
+ deformable_groups=8,
278
+ num_extract_block=5,
279
+ num_reconstruct_block=10,
280
+ center_frame_idx=None,
281
+ hr_in=False,
282
+ with_predeblur=False,
283
+ with_tsa=True):
284
+ super(EDVR, self).__init__()
285
+ if center_frame_idx is None:
286
+ self.center_frame_idx = num_frame // 2
287
+ else:
288
+ self.center_frame_idx = center_frame_idx
289
+ self.hr_in = hr_in
290
+ self.with_predeblur = with_predeblur
291
+ self.with_tsa = with_tsa
292
+
293
+ # extract features for each frame
294
+ if self.with_predeblur:
295
+ self.predeblur = PredeblurModule(num_feat=num_feat, hr_in=self.hr_in)
296
+ self.conv_1x1 = nn.Conv2d(num_feat, num_feat, 1, 1)
297
+ else:
298
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
299
+
300
+ # extract pyramid features
301
+ self.feature_extraction = make_layer(ResidualBlockNoBN, num_extract_block, num_feat=num_feat)
302
+ self.conv_l2_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
303
+ self.conv_l2_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
304
+ self.conv_l3_1 = nn.Conv2d(num_feat, num_feat, 3, 2, 1)
305
+ self.conv_l3_2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
306
+
307
+ # pcd and tsa module
308
+ self.pcd_align = PCDAlignment(num_feat=num_feat, deformable_groups=deformable_groups)
309
+ if self.with_tsa:
310
+ self.fusion = TSAFusion(num_feat=num_feat, num_frame=num_frame, center_frame_idx=self.center_frame_idx)
311
+ else:
312
+ self.fusion = nn.Conv2d(num_frame * num_feat, num_feat, 1, 1)
313
+
314
+ # reconstruction
315
+ self.reconstruction = make_layer(ResidualBlockNoBN, num_reconstruct_block, num_feat=num_feat)
316
+ # upsample
317
+ self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
318
+ self.upconv2 = nn.Conv2d(num_feat, 64 * 4, 3, 1, 1)
319
+ self.pixel_shuffle = nn.PixelShuffle(2)
320
+ self.conv_hr = nn.Conv2d(64, 64, 3, 1, 1)
321
+ self.conv_last = nn.Conv2d(64, 3, 3, 1, 1)
322
+
323
+ # activation function
324
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
325
+
326
+ def forward(self, x):
327
+ b, t, c, h, w = x.size()
328
+ if self.hr_in:
329
+ assert h % 16 == 0 and w % 16 == 0, ('The height and width must be multiple of 16.')
330
+ else:
331
+ assert h % 4 == 0 and w % 4 == 0, ('The height and width must be multiple of 4.')
332
+
333
+ x_center = x[:, self.center_frame_idx, :, :, :].contiguous()
334
+
335
+ # extract features for each frame
336
+ # L1
337
+ if self.with_predeblur:
338
+ feat_l1 = self.conv_1x1(self.predeblur(x.view(-1, c, h, w)))
339
+ if self.hr_in:
340
+ h, w = h // 4, w // 4
341
+ else:
342
+ feat_l1 = self.lrelu(self.conv_first(x.view(-1, c, h, w)))
343
+
344
+ feat_l1 = self.feature_extraction(feat_l1)
345
+ # L2
346
+ feat_l2 = self.lrelu(self.conv_l2_1(feat_l1))
347
+ feat_l2 = self.lrelu(self.conv_l2_2(feat_l2))
348
+ # L3
349
+ feat_l3 = self.lrelu(self.conv_l3_1(feat_l2))
350
+ feat_l3 = self.lrelu(self.conv_l3_2(feat_l3))
351
+
352
+ feat_l1 = feat_l1.view(b, t, -1, h, w)
353
+ feat_l2 = feat_l2.view(b, t, -1, h // 2, w // 2)
354
+ feat_l3 = feat_l3.view(b, t, -1, h // 4, w // 4)
355
+
356
+ # PCD alignment
357
+ ref_feat_l = [ # reference feature list
358
+ feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
359
+ feat_l3[:, self.center_frame_idx, :, :, :].clone()
360
+ ]
361
+ aligned_feat = []
362
+ for i in range(t):
363
+ nbr_feat_l = [ # neighboring feature list
364
+ feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
365
+ ]
366
+ aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
367
+ aligned_feat = torch.stack(aligned_feat, dim=1) # (b, t, c, h, w)
368
+
369
+ if not self.with_tsa:
370
+ aligned_feat = aligned_feat.view(b, -1, h, w)
371
+ feat = self.fusion(aligned_feat)
372
+
373
+ out = self.reconstruction(feat)
374
+ out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
375
+ out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
376
+ out = self.lrelu(self.conv_hr(out))
377
+ out = self.conv_last(out)
378
+ if self.hr_in:
379
+ base = x_center
380
+ else:
381
+ base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False)
382
+ out += base
383
+ return out
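The full EDVR model needs the deformable-convolution op behind DCNv2Pack, which may require a compiled extension; the TSA fusion module, however, is plain PyTorch and can be tried on its own. Below is a minimal sketch, not part of the diff, assuming r_basicsr.archs.edvr_arch imports cleanly in your environment.

import torch
from r_basicsr.archs.edvr_arch import TSAFusion

fusion = TSAFusion(num_feat=64, num_frame=5, center_frame_idx=2)
aligned = torch.randn(1, 5, 64, 64, 64)  # (b, t, c, h, w): 5 already-aligned feature maps
with torch.no_grad():
    fused = fusion(aligned)
print(fused.shape)  # torch.Size([1, 64, 64, 64]) - one fused feature map for reconstruction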
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/hifacegan_arch.py ADDED
@@ -0,0 +1,259 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from r_basicsr.utils.registry import ARCH_REGISTRY
7
+ from .hifacegan_util import BaseNetwork, LIPEncoder, SPADEResnetBlock, get_nonspade_norm_layer
8
+
9
+
10
+ class SPADEGenerator(BaseNetwork):
11
+ """Generator with SPADEResBlock"""
12
+
13
+ def __init__(self,
14
+ num_in_ch=3,
15
+ num_feat=64,
16
+ use_vae=False,
17
+ z_dim=256,
18
+ crop_size=512,
19
+ norm_g='spectralspadesyncbatch3x3',
20
+ is_train=True,
21
+ init_train_phase=3): # progressive training disabled
22
+ super().__init__()
23
+ self.nf = num_feat
24
+ self.input_nc = num_in_ch
25
+ self.is_train = is_train
26
+ self.train_phase = init_train_phase
27
+
28
+ self.scale_ratio = 5 # hardcoded now
29
+ self.sw = crop_size // (2**self.scale_ratio)
30
+ self.sh = self.sw # 20210519: By default use square image, aspect_ratio = 1.0
31
+
32
+ if use_vae:
33
+ # In case of VAE, we will sample from random z vector
34
+ self.fc = nn.Linear(z_dim, 16 * self.nf * self.sw * self.sh)
35
+ else:
36
+ # Otherwise, we make the network deterministic by starting with
37
+ # downsampled segmentation map instead of random z
38
+ self.fc = nn.Conv2d(num_in_ch, 16 * self.nf, 3, padding=1)
39
+
40
+ self.head_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
41
+
42
+ self.g_middle_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
43
+ self.g_middle_1 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
44
+
45
+ self.ups = nn.ModuleList([
46
+ SPADEResnetBlock(16 * self.nf, 8 * self.nf, norm_g),
47
+ SPADEResnetBlock(8 * self.nf, 4 * self.nf, norm_g),
48
+ SPADEResnetBlock(4 * self.nf, 2 * self.nf, norm_g),
49
+ SPADEResnetBlock(2 * self.nf, 1 * self.nf, norm_g)
50
+ ])
51
+
52
+ self.to_rgbs = nn.ModuleList([
53
+ nn.Conv2d(8 * self.nf, 3, 3, padding=1),
54
+ nn.Conv2d(4 * self.nf, 3, 3, padding=1),
55
+ nn.Conv2d(2 * self.nf, 3, 3, padding=1),
56
+ nn.Conv2d(1 * self.nf, 3, 3, padding=1)
57
+ ])
58
+
59
+ self.up = nn.Upsample(scale_factor=2)
60
+
61
+ def encode(self, input_tensor):
62
+ """
63
+ Encode input_tensor into feature maps, can be overridden in derived classes
64
+ Default: nearest downsampling of 2**5 = 32 times
65
+ """
66
+ h, w = input_tensor.size()[-2:]
67
+ sh, sw = h // 2**self.scale_ratio, w // 2**self.scale_ratio
68
+ x = F.interpolate(input_tensor, size=(sh, sw))
69
+ return self.fc(x)
70
+
71
+ def forward(self, x):
72
+ # In the original SPADE, seg means a segmentation map, but here we use x instead.
73
+ seg = x
74
+
75
+ x = self.encode(x)
76
+ x = self.head_0(x, seg)
77
+
78
+ x = self.up(x)
79
+ x = self.g_middle_0(x, seg)
80
+ x = self.g_middle_1(x, seg)
81
+
82
+ if self.is_train:
83
+ phase = self.train_phase + 1
84
+ else:
85
+ phase = len(self.to_rgbs)
86
+
87
+ for i in range(phase):
88
+ x = self.up(x)
89
+ x = self.ups[i](x, seg)
90
+
91
+ x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
92
+ x = torch.tanh(x)
93
+
94
+ return x
95
+
96
+ def mixed_guidance_forward(self, input_x, seg=None, n=0, mode='progressive'):
97
+ """
98
+ A helper method for subspace visualization. input_x and seg are different images.
99
+ For the first n levels (including encoder) we use input, for the rest we use seg.
100
+
101
+ If mode = 'progressive', the output's like: AAABBB
102
+ If mode = 'one_plug', the output's like: AAABAA
103
+ If mode = 'one_ablate', the output's like: BBBABB
104
+ """
105
+
106
+ if seg is None:
107
+ return self.forward(input_x)
108
+
109
+ if self.is_train:
110
+ phase = self.train_phase + 1
111
+ else:
112
+ phase = len(self.to_rgbs)
113
+
114
+ if mode == 'progressive':
115
+ n = max(min(n, 4 + phase), 0)
116
+ guide_list = [input_x] * n + [seg] * (4 + phase - n)
117
+ elif mode == 'one_plug':
118
+ n = max(min(n, 4 + phase - 1), 0)
119
+ guide_list = [seg] * (4 + phase)
120
+ guide_list[n] = input_x
121
+ elif mode == 'one_ablate':
122
+ if n > 3 + phase:
123
+ return self.forward(input_x)
124
+ guide_list = [input_x] * (4 + phase)
125
+ guide_list[n] = seg
126
+
127
+ x = self.encode(guide_list[0])
128
+ x = self.head_0(x, guide_list[1])
129
+
130
+ x = self.up(x)
131
+ x = self.g_middle_0(x, guide_list[2])
132
+ x = self.g_middle_1(x, guide_list[3])
133
+
134
+ for i in range(phase):
135
+ x = self.up(x)
136
+ x = self.ups[i](x, guide_list[4 + i])
137
+
138
+ x = self.to_rgbs[phase - 1](F.leaky_relu(x, 2e-1))
139
+ x = torch.tanh(x)
140
+
141
+ return x
142
+
143
+
144
+ @ARCH_REGISTRY.register()
145
+ class HiFaceGAN(SPADEGenerator):
146
+ """
147
+ HiFaceGAN: SPADEGenerator with a learnable feature encoder
148
+ Current encoder design: LIPEncoder
149
+ """
150
+
151
+ def __init__(self,
152
+ num_in_ch=3,
153
+ num_feat=64,
154
+ use_vae=False,
155
+ z_dim=256,
156
+ crop_size=512,
157
+ norm_g='spectralspadesyncbatch3x3',
158
+ is_train=True,
159
+ init_train_phase=3):
160
+ super().__init__(num_in_ch, num_feat, use_vae, z_dim, crop_size, norm_g, is_train, init_train_phase)
161
+ self.lip_encoder = LIPEncoder(num_in_ch, num_feat, self.sw, self.sh, self.scale_ratio)
162
+
163
+ def encode(self, input_tensor):
164
+ return self.lip_encoder(input_tensor)
165
+
166
+
167
+ @ARCH_REGISTRY.register()
168
+ class HiFaceGANDiscriminator(BaseNetwork):
169
+ """
170
+ Inspired by pix2pixHD multiscale discriminator.
171
+ Args:
172
+ num_in_ch (int): Channel number of inputs. Default: 3.
173
+ num_out_ch (int): Channel number of outputs. Default: 3.
174
+ conditional_d (bool): Whether use conditional discriminator.
175
+ Default: True.
176
+ num_d (int): Number of Multiscale discriminators. Default: 3.
177
+ n_layers_d (int): Number of downsample layers in each D. Default: 4.
178
+ num_feat (int): Channel number of base intermediate features.
179
+ Default: 64.
180
+ norm_d (str): String to determine normalization layers in D.
181
+ Choices: [spectral][instance/batch/syncbatch]
182
+ Default: 'spectralinstance'.
183
+ keep_features (bool): Keep intermediate features for matching loss, etc.
184
+ Default: True.
185
+ """
186
+
187
+ def __init__(self,
188
+ num_in_ch=3,
189
+ num_out_ch=3,
190
+ conditional_d=True,
191
+ num_d=2,
192
+ n_layers_d=4,
193
+ num_feat=64,
194
+ norm_d='spectralinstance',
195
+ keep_features=True):
196
+ super().__init__()
197
+ self.num_d = num_d
198
+
199
+ input_nc = num_in_ch
200
+ if conditional_d:
201
+ input_nc += num_out_ch
202
+
203
+ for i in range(num_d):
204
+ subnet_d = NLayerDiscriminator(input_nc, n_layers_d, num_feat, norm_d, keep_features)
205
+ self.add_module(f'discriminator_{i}', subnet_d)
206
+
207
+ def downsample(self, x):
208
+ return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
209
+
210
+ # Returns list of lists of discriminator outputs.
211
+ # The final result is of size opt.num_d x opt.n_layers_D
212
+ def forward(self, x):
213
+ result = []
214
+ for _, _net_d in self.named_children():
215
+ out = _net_d(x)
216
+ result.append(out)
217
+ x = self.downsample(x)
218
+
219
+ return result
220
+
221
+
222
+ class NLayerDiscriminator(BaseNetwork):
223
+ """Defines the PatchGAN discriminator with the specified arguments."""
224
+
225
+ def __init__(self, input_nc, n_layers_d, num_feat, norm_d, keep_features):
226
+ super().__init__()
227
+ kw = 4
228
+ padw = int(np.ceil((kw - 1.0) / 2))
229
+ nf = num_feat
230
+ self.keep_features = keep_features
231
+
232
+ norm_layer = get_nonspade_norm_layer(norm_d)
233
+ sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, False)]]
234
+
235
+ for n in range(1, n_layers_d):
236
+ nf_prev = nf
237
+ nf = min(nf * 2, 512)
238
+ stride = 1 if n == n_layers_d - 1 else 2
239
+ sequence += [[
240
+ norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=stride, padding=padw)),
241
+ nn.LeakyReLU(0.2, False)
242
+ ]]
243
+
244
+ sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
245
+
246
+ # We divide the layers into groups to extract intermediate layer outputs
247
+ for n in range(len(sequence)):
248
+ self.add_module('model' + str(n), nn.Sequential(*sequence[n]))
249
+
250
+ def forward(self, x):
251
+ results = [x]
252
+ for submodel in self.children():
253
+ intermediate_output = submodel(results[-1])
254
+ results.append(intermediate_output)
255
+
256
+ if self.keep_features:
257
+ return results[1:]
258
+ else:
259
+ return results[-1]
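For orientation, a minimal usage sketch of the multiscale discriminator defined above (a sketch under assumptions: the package is importable as r_basicsr and the defaults conditional_d=True, num_d=2, keep_features=True are kept, so the input is a 3-channel condition concatenated with a 3-channel candidate image):

import torch
from r_basicsr.archs.hifacegan_arch import HiFaceGANDiscriminator

disc = HiFaceGANDiscriminator()        # conditional_d=True -> expects 3 + 3 input channels
pair = torch.rand(1, 6, 256, 256)      # condition image concatenated with a candidate restoration
outputs = disc(pair)                   # one entry per scale (num_d=2)
for scale, feats in enumerate(outputs):
    # keep_features=True -> each entry is the list of per-layer feature maps
    print(scale, [tuple(f.shape) for f in feats])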
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/hifacegan_util.py ADDED
@@ -0,0 +1,255 @@
1
+ import re
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch.nn import init
6
+ # Warning: spectral norm could be buggy
7
+ # under eval mode and multi-GPU inference
8
+ # A workaround is sticking to single-GPU inference and train mode
9
+ from torch.nn.utils import spectral_norm
10
+
11
+
12
+ class SPADE(nn.Module):
13
+
14
+ def __init__(self, config_text, norm_nc, label_nc):
15
+ super().__init__()
16
+
17
+ assert config_text.startswith('spade')
18
+ parsed = re.search('spade(\\D+)(\\d)x\\d', config_text)
19
+ param_free_norm_type = str(parsed.group(1))
20
+ ks = int(parsed.group(2))
21
+
22
+ if param_free_norm_type == 'instance':
23
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc)
24
+ elif param_free_norm_type == 'syncbatch':
25
+ print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
26
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc)
27
+ elif param_free_norm_type == 'batch':
28
+ self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
29
+ else:
30
+ raise ValueError(f'{param_free_norm_type} is not a recognized param-free norm type in SPADE')
31
+
32
+ # The dimension of the intermediate embedding space. Yes, hardcoded.
33
+ nhidden = 128 if norm_nc > 128 else norm_nc
34
+
35
+ pw = ks // 2
36
+ self.mlp_shared = nn.Sequential(nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
37
+ self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
38
+ self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw, bias=False)
39
+
40
+ def forward(self, x, segmap):
41
+
42
+ # Part 1. generate parameter-free normalized activations
43
+ normalized = self.param_free_norm(x)
44
+
45
+ # Part 2. produce scaling and bias conditioned on semantic map
46
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
47
+ actv = self.mlp_shared(segmap)
48
+ gamma = self.mlp_gamma(actv)
49
+ beta = self.mlp_beta(actv)
50
+
51
+ # apply scale and bias
52
+ out = normalized * gamma + beta
53
+
54
+ return out
55
+
56
+
57
+ class SPADEResnetBlock(nn.Module):
58
+ """
59
+ ResNet block that uses SPADE. It differs from the ResNet block of pix2pixHD in that
60
+ it takes in the segmentation map as input, learns the skip connection if necessary,
61
+ and applies normalization first and then convolution.
62
+ This is a fairly standard residual-block design for unconditional or
63
+ class-conditional GANs.
64
+ The code was inspired by https://github.com/LMescheder/GAN_stability.
65
+ """
66
+
67
+ def __init__(self, fin, fout, norm_g='spectralspadesyncbatch3x3', semantic_nc=3):
68
+ super().__init__()
69
+ # Attributes
70
+ self.learned_shortcut = (fin != fout)
71
+ fmiddle = min(fin, fout)
72
+
73
+ # create conv layers
74
+ self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
75
+ self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
76
+ if self.learned_shortcut:
77
+ self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
78
+
79
+ # apply spectral norm if specified
80
+ if 'spectral' in norm_g:
81
+ self.conv_0 = spectral_norm(self.conv_0)
82
+ self.conv_1 = spectral_norm(self.conv_1)
83
+ if self.learned_shortcut:
84
+ self.conv_s = spectral_norm(self.conv_s)
85
+
86
+ # define normalization layers
87
+ spade_config_str = norm_g.replace('spectral', '')
88
+ self.norm_0 = SPADE(spade_config_str, fin, semantic_nc)
89
+ self.norm_1 = SPADE(spade_config_str, fmiddle, semantic_nc)
90
+ if self.learned_shortcut:
91
+ self.norm_s = SPADE(spade_config_str, fin, semantic_nc)
92
+
93
+ # note the resnet block with SPADE also takes in |seg|,
94
+ # the semantic segmentation map as input
95
+ def forward(self, x, seg):
96
+ x_s = self.shortcut(x, seg)
97
+ dx = self.conv_0(self.act(self.norm_0(x, seg)))
98
+ dx = self.conv_1(self.act(self.norm_1(dx, seg)))
99
+ out = x_s + dx
100
+ return out
101
+
102
+ def shortcut(self, x, seg):
103
+ if self.learned_shortcut:
104
+ x_s = self.conv_s(self.norm_s(x, seg))
105
+ else:
106
+ x_s = x
107
+ return x_s
108
+
109
+ def act(self, x):
110
+ return F.leaky_relu(x, 2e-1)
111
+
112
+
113
+ class BaseNetwork(nn.Module):
114
+ """ A basis for hifacegan archs with custom initialization """
115
+
116
+ def init_weights(self, init_type='normal', gain=0.02):
117
+
118
+ def init_func(m):
119
+ classname = m.__class__.__name__
120
+ if classname.find('BatchNorm2d') != -1:
121
+ if hasattr(m, 'weight') and m.weight is not None:
122
+ init.normal_(m.weight.data, 1.0, gain)
123
+ if hasattr(m, 'bias') and m.bias is not None:
124
+ init.constant_(m.bias.data, 0.0)
125
+ elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
126
+ if init_type == 'normal':
127
+ init.normal_(m.weight.data, 0.0, gain)
128
+ elif init_type == 'xavier':
129
+ init.xavier_normal_(m.weight.data, gain=gain)
130
+ elif init_type == 'xavier_uniform':
131
+ init.xavier_uniform_(m.weight.data, gain=1.0)
132
+ elif init_type == 'kaiming':
133
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
134
+ elif init_type == 'orthogonal':
135
+ init.orthogonal_(m.weight.data, gain=gain)
136
+ elif init_type == 'none': # uses pytorch's default init method
137
+ m.reset_parameters()
138
+ else:
139
+ raise NotImplementedError(f'initialization method [{init_type}] is not implemented')
140
+ if hasattr(m, 'bias') and m.bias is not None:
141
+ init.constant_(m.bias.data, 0.0)
142
+
143
+ self.apply(init_func)
144
+
145
+ # propagate to children
146
+ for m in self.children():
147
+ if hasattr(m, 'init_weights'):
148
+ m.init_weights(init_type, gain)
149
+
150
+ def forward(self, x):
151
+ pass
152
+
153
+
154
+ def lip2d(x, logit, kernel=3, stride=2, padding=1):
155
+ weight = logit.exp()
156
+ return F.avg_pool2d(x * weight, kernel, stride, padding) / F.avg_pool2d(weight, kernel, stride, padding)
157
+
158
+
159
+ class SoftGate(nn.Module):
160
+ COEFF = 12.0
161
+
162
+ def forward(self, x):
163
+ return torch.sigmoid(x).mul(self.COEFF)
164
+
165
+
166
+ class SimplifiedLIP(nn.Module):
167
+
168
+ def __init__(self, channels):
169
+ super(SimplifiedLIP, self).__init__()
170
+ self.logit = nn.Sequential(
171
+ nn.Conv2d(channels, channels, 3, padding=1, bias=False), nn.InstanceNorm2d(channels, affine=True),
172
+ SoftGate())
173
+
174
+ def init_layer(self):
175
+ self.logit[0].weight.data.fill_(0.0)
176
+
177
+ def forward(self, x):
178
+ frac = lip2d(x, self.logit(x))
179
+ return frac
180
+
181
+
182
+ class LIPEncoder(BaseNetwork):
183
+ """Local Importance-based Pooling (Ziteng Gao et.al.,ICCV 2019)"""
184
+
185
+ def __init__(self, input_nc, ngf, sw, sh, n_2xdown, norm_layer=nn.InstanceNorm2d):
186
+ super().__init__()
187
+ self.sw = sw
188
+ self.sh = sh
189
+ self.max_ratio = 16
190
+ # 20200310: Several Convolution (stride 1) + LIP blocks, 4 fold
191
+ kw = 3
192
+ pw = (kw - 1) // 2
193
+
194
+ model = [
195
+ nn.Conv2d(input_nc, ngf, kw, stride=1, padding=pw, bias=False),
196
+ norm_layer(ngf),
197
+ nn.ReLU(),
198
+ ]
199
+ cur_ratio = 1
200
+ for i in range(n_2xdown):
201
+ next_ratio = min(cur_ratio * 2, self.max_ratio)
202
+ model += [
203
+ SimplifiedLIP(ngf * cur_ratio),
204
+ nn.Conv2d(ngf * cur_ratio, ngf * next_ratio, kw, stride=1, padding=pw),
205
+ norm_layer(ngf * next_ratio),
206
+ ]
207
+ cur_ratio = next_ratio
208
+ if i < n_2xdown - 1:
209
+ model += [nn.ReLU(inplace=True)]
210
+
211
+ self.model = nn.Sequential(*model)
212
+
213
+ def forward(self, x):
214
+ return self.model(x)
215
+
216
+
217
+ def get_nonspade_norm_layer(norm_type='instance'):
218
+ # helper function to get # output channels of the previous layer
219
+ def get_out_channel(layer):
220
+ if hasattr(layer, 'out_channels'):
221
+ return getattr(layer, 'out_channels')
222
+ return layer.weight.size(0)
223
+
224
+ # this function will be returned
225
+ def add_norm_layer(layer):
226
+ nonlocal norm_type
227
+ if norm_type.startswith('spectral'):
228
+ layer = spectral_norm(layer)
229
+ subnorm_type = norm_type[len('spectral'):]
230
+
231
+ if subnorm_type == 'none' or len(subnorm_type) == 0:
232
+ return layer
233
+
234
+ # remove bias in the previous layer, which is meaningless
235
+ # since it has no effect after normalization
236
+ if getattr(layer, 'bias', None) is not None:
237
+ delattr(layer, 'bias')
238
+ layer.register_parameter('bias', None)
239
+
240
+ if subnorm_type == 'batch':
241
+ norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
242
+ elif subnorm_type == 'sync_batch':
243
+ print('SyncBatchNorm is currently not supported under single-GPU mode, switch to "instance" instead')
244
+ # norm_layer = SynchronizedBatchNorm2d(
245
+ # get_out_channel(layer), affine=True)
246
+ norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
247
+ elif subnorm_type == 'instance':
248
+ norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
249
+ else:
250
+ raise ValueError(f'normalization layer {subnorm_type} is not recognized')
251
+
252
+ return nn.Sequential(layer, norm_layer)
253
+
254
+ print('This is a legacy from nvlabs/SPADE, and will be removed in future versions.')
255
+ return add_norm_layer
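As a quick illustration of the helper above (a sketch, assuming the package is importable as r_basicsr): 'spectralinstance' wraps the layer in spectral norm, drops its bias and appends an InstanceNorm2d sized to the layer's output channels.

import torch
import torch.nn as nn
from r_basicsr.archs.hifacegan_util import get_nonspade_norm_layer

add_norm = get_nonspade_norm_layer('spectralinstance')
layer = add_norm(nn.Conv2d(3, 64, 3, padding=1))   # Sequential(spectral_norm(conv), InstanceNorm2d(64))
y = layer(torch.rand(1, 3, 64, 64))
print(y.shape)                                     # torch.Size([1, 64, 64, 64])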
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/inception.py ADDED
@@ -0,0 +1,307 @@
1
+ # Modified from https://github.com/mseitzer/pytorch-fid/blob/master/pytorch_fid/inception.py # noqa: E501
2
+ # For FID metric
3
+
4
+ import os
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.utils.model_zoo import load_url
9
+ from torchvision import models
10
+
11
+ # Inception weights ported to Pytorch from
12
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
13
+ FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
14
+ LOCAL_FID_WEIGHTS = 'experiments/pretrained_models/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
15
+
16
+
17
+ class InceptionV3(nn.Module):
18
+ """Pretrained InceptionV3 network returning feature maps"""
19
+
20
+ # Index of default block of inception to return,
21
+ # corresponds to output of final average pooling
22
+ DEFAULT_BLOCK_INDEX = 3
23
+
24
+ # Maps feature dimensionality to their output blocks indices
25
+ BLOCK_INDEX_BY_DIM = {
26
+ 64: 0, # First max pooling features
27
+ 192: 1, # Second max pooling features
28
+ 768: 2, # Pre-aux classifier features
29
+ 2048: 3 # Final average pooling features
30
+ }
31
+
32
+ def __init__(self,
33
+ output_blocks=(DEFAULT_BLOCK_INDEX,),  # trailing comma keeps this a tuple; a bare int would break sorted()/max() below
34
+ resize_input=True,
35
+ normalize_input=True,
36
+ requires_grad=False,
37
+ use_fid_inception=True):
38
+ """Build pretrained InceptionV3.
39
+
40
+ Args:
41
+ output_blocks (list[int]): Indices of blocks to return features of.
42
+ Possible values are:
43
+ - 0: corresponds to output of first max pooling
44
+ - 1: corresponds to output of second max pooling
45
+ - 2: corresponds to output which is fed to aux classifier
46
+ - 3: corresponds to output of final average pooling
47
+ resize_input (bool): If true, bilinearly resizes input to width and
48
+ height 299 before feeding input to model. As the network
49
+ without fully connected layers is fully convolutional, it
50
+ should be able to handle inputs of arbitrary size, so resizing
51
+ might not be strictly needed. Default: True.
52
+ normalize_input (bool): If true, scales the input from range (0, 1)
53
+ to the range the pretrained Inception network expects,
54
+ namely (-1, 1). Default: True.
55
+ requires_grad (bool): If true, parameters of the model require
56
+ gradients. Possibly useful for finetuning the network.
57
+ Default: False.
58
+ use_fid_inception (bool): If true, uses the pretrained Inception
59
+ model used in Tensorflow's FID implementation.
60
+ If false, uses the pretrained Inception model available in
61
+ torchvision. The FID Inception model has different weights
62
+ and a slightly different structure from torchvision's
63
+ Inception model. If you want to compute FID scores, you are
64
+ strongly advised to set this parameter to true to get
65
+ comparable results. Default: True.
66
+ """
67
+ super(InceptionV3, self).__init__()
68
+
69
+ self.resize_input = resize_input
70
+ self.normalize_input = normalize_input
71
+ self.output_blocks = sorted(output_blocks)
72
+ self.last_needed_block = max(output_blocks)
73
+
74
+ assert self.last_needed_block <= 3, ('Last possible output block index is 3')
75
+
76
+ self.blocks = nn.ModuleList()
77
+
78
+ if use_fid_inception:
79
+ inception = fid_inception_v3()
80
+ else:
81
+ try:
82
+ inception = models.inception_v3(pretrained=True, init_weights=False)
83
+ except TypeError:
84
+ # pytorch < 1.5 does not have init_weights for inception_v3
85
+ inception = models.inception_v3(pretrained=True)
86
+
87
+ # Block 0: input to maxpool1
88
+ block0 = [
89
+ inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, inception.Conv2d_2b_3x3,
90
+ nn.MaxPool2d(kernel_size=3, stride=2)
91
+ ]
92
+ self.blocks.append(nn.Sequential(*block0))
93
+
94
+ # Block 1: maxpool1 to maxpool2
95
+ if self.last_needed_block >= 1:
96
+ block1 = [inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3, nn.MaxPool2d(kernel_size=3, stride=2)]
97
+ self.blocks.append(nn.Sequential(*block1))
98
+
99
+ # Block 2: maxpool2 to aux classifier
100
+ if self.last_needed_block >= 2:
101
+ block2 = [
102
+ inception.Mixed_5b,
103
+ inception.Mixed_5c,
104
+ inception.Mixed_5d,
105
+ inception.Mixed_6a,
106
+ inception.Mixed_6b,
107
+ inception.Mixed_6c,
108
+ inception.Mixed_6d,
109
+ inception.Mixed_6e,
110
+ ]
111
+ self.blocks.append(nn.Sequential(*block2))
112
+
113
+ # Block 3: aux classifier to final avgpool
114
+ if self.last_needed_block >= 3:
115
+ block3 = [
116
+ inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c,
117
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
118
+ ]
119
+ self.blocks.append(nn.Sequential(*block3))
120
+
121
+ for param in self.parameters():
122
+ param.requires_grad = requires_grad
123
+
124
+ def forward(self, x):
125
+ """Get Inception feature maps.
126
+
127
+ Args:
128
+ x (Tensor): Input tensor of shape (b, 3, h, w).
129
+ Values are expected to be in range (-1, 1). You can also input
130
+ (0, 1) with setting normalize_input = True.
131
+
132
+ Returns:
133
+ list[Tensor]: Corresponding to the selected output block, sorted
134
+ ascending by index.
135
+ """
136
+ output = []
137
+
138
+ if self.resize_input:
139
+ x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
140
+
141
+ if self.normalize_input:
142
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
143
+
144
+ for idx, block in enumerate(self.blocks):
145
+ x = block(x)
146
+ if idx in self.output_blocks:
147
+ output.append(x)
148
+
149
+ if idx == self.last_needed_block:
150
+ break
151
+
152
+ return output
153
+
154
+
155
+ def fid_inception_v3():
156
+ """Build pretrained Inception model for FID computation.
157
+
158
+ The Inception model for FID computation uses a different set of weights
159
+ and has a slightly different structure than torchvision's Inception.
160
+
161
+ This method first constructs torchvision's Inception and then patches the
162
+ necessary parts that are different in the FID Inception model.
163
+ """
164
+ try:
165
+ inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False, init_weights=False)
166
+ except TypeError:
167
+ # pytorch < 1.5 does not have init_weights for inception_v3
168
+ inception = models.inception_v3(num_classes=1008, aux_logits=False, pretrained=False)
169
+
170
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
171
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
172
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
173
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
174
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
175
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
176
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
177
+ inception.Mixed_7b = FIDInceptionE_1(1280)
178
+ inception.Mixed_7c = FIDInceptionE_2(2048)
179
+
180
+ if os.path.exists(LOCAL_FID_WEIGHTS):
181
+ state_dict = torch.load(LOCAL_FID_WEIGHTS, map_location=lambda storage, loc: storage)
182
+ else:
183
+ state_dict = load_url(FID_WEIGHTS_URL, progress=True)
184
+
185
+ inception.load_state_dict(state_dict)
186
+ return inception
187
+
188
+
189
+ class FIDInceptionA(models.inception.InceptionA):
190
+ """InceptionA block patched for FID computation"""
191
+
192
+ def __init__(self, in_channels, pool_features):
193
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
194
+
195
+ def forward(self, x):
196
+ branch1x1 = self.branch1x1(x)
197
+
198
+ branch5x5 = self.branch5x5_1(x)
199
+ branch5x5 = self.branch5x5_2(branch5x5)
200
+
201
+ branch3x3dbl = self.branch3x3dbl_1(x)
202
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
203
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
204
+
205
+ # Patch: Tensorflow's average pool does not use the padded zeros in
206
+ # its average calculation
207
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
208
+ branch_pool = self.branch_pool(branch_pool)
209
+
210
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
211
+ return torch.cat(outputs, 1)
212
+
213
+
214
+ class FIDInceptionC(models.inception.InceptionC):
215
+ """InceptionC block patched for FID computation"""
216
+
217
+ def __init__(self, in_channels, channels_7x7):
218
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
219
+
220
+ def forward(self, x):
221
+ branch1x1 = self.branch1x1(x)
222
+
223
+ branch7x7 = self.branch7x7_1(x)
224
+ branch7x7 = self.branch7x7_2(branch7x7)
225
+ branch7x7 = self.branch7x7_3(branch7x7)
226
+
227
+ branch7x7dbl = self.branch7x7dbl_1(x)
228
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
229
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
230
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
231
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
232
+
233
+ # Patch: Tensorflow's average pool does not use the padded zeros in
234
+ # its average calculation
235
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
236
+ branch_pool = self.branch_pool(branch_pool)
237
+
238
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
239
+ return torch.cat(outputs, 1)
240
+
241
+
242
+ class FIDInceptionE_1(models.inception.InceptionE):
243
+ """First InceptionE block patched for FID computation"""
244
+
245
+ def __init__(self, in_channels):
246
+ super(FIDInceptionE_1, self).__init__(in_channels)
247
+
248
+ def forward(self, x):
249
+ branch1x1 = self.branch1x1(x)
250
+
251
+ branch3x3 = self.branch3x3_1(x)
252
+ branch3x3 = [
253
+ self.branch3x3_2a(branch3x3),
254
+ self.branch3x3_2b(branch3x3),
255
+ ]
256
+ branch3x3 = torch.cat(branch3x3, 1)
257
+
258
+ branch3x3dbl = self.branch3x3dbl_1(x)
259
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
260
+ branch3x3dbl = [
261
+ self.branch3x3dbl_3a(branch3x3dbl),
262
+ self.branch3x3dbl_3b(branch3x3dbl),
263
+ ]
264
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
265
+
266
+ # Patch: Tensorflow's average pool does not use the padded zeros in
267
+ # its average calculation
268
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, count_include_pad=False)
269
+ branch_pool = self.branch_pool(branch_pool)
270
+
271
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
272
+ return torch.cat(outputs, 1)
273
+
274
+
275
+ class FIDInceptionE_2(models.inception.InceptionE):
276
+ """Second InceptionE block patched for FID computation"""
277
+
278
+ def __init__(self, in_channels):
279
+ super(FIDInceptionE_2, self).__init__(in_channels)
280
+
281
+ def forward(self, x):
282
+ branch1x1 = self.branch1x1(x)
283
+
284
+ branch3x3 = self.branch3x3_1(x)
285
+ branch3x3 = [
286
+ self.branch3x3_2a(branch3x3),
287
+ self.branch3x3_2b(branch3x3),
288
+ ]
289
+ branch3x3 = torch.cat(branch3x3, 1)
290
+
291
+ branch3x3dbl = self.branch3x3dbl_1(x)
292
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
293
+ branch3x3dbl = [
294
+ self.branch3x3dbl_3a(branch3x3dbl),
295
+ self.branch3x3dbl_3b(branch3x3dbl),
296
+ ]
297
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
298
+
299
+ # Patch: The FID Inception model uses max pooling instead of average
300
+ # pooling. This is likely an error in this specific Inception
301
+ # implementation, as other Inception models use average pooling here
302
+ # (which matches the description in the paper).
303
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
304
+ branch_pool = self.branch_pool(branch_pool)
305
+
306
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
307
+ return torch.cat(outputs, 1)
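A sketch of extracting the 2048-dimensional pooling features commonly used for FID (assuming the package is importable as r_basicsr; with the default use_fid_inception=True the weights are fetched from FID_WEIGHTS_URL unless the local copy at LOCAL_FID_WEIGHTS exists):

import torch
from r_basicsr.archs.inception import InceptionV3

# Block index 3 corresponds to the final average pooling, the standard choice for FID.
model = InceptionV3(output_blocks=[3]).eval()

imgs = torch.rand(4, 3, 299, 299)        # values in (0, 1); normalize_input rescales them to (-1, 1)
with torch.no_grad():
    feats = model(imgs)[0]               # only the requested block is returned
print(feats.shape)                       # torch.Size([4, 2048, 1, 1])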
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/rcan_arch.py ADDED
@@ -0,0 +1,135 @@
1
+ import torch
2
+ from torch import nn as nn
3
+
4
+ from r_basicsr.utils.registry import ARCH_REGISTRY
5
+ from .arch_util import Upsample, make_layer
6
+
7
+
8
+ class ChannelAttention(nn.Module):
9
+ """Channel attention used in RCAN.
10
+
11
+ Args:
12
+ num_feat (int): Channel number of intermediate features.
13
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
14
+ """
15
+
16
+ def __init__(self, num_feat, squeeze_factor=16):
17
+ super(ChannelAttention, self).__init__()
18
+ self.attention = nn.Sequential(
19
+ nn.AdaptiveAvgPool2d(1), nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
20
+ nn.ReLU(inplace=True), nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), nn.Sigmoid())
21
+
22
+ def forward(self, x):
23
+ y = self.attention(x)
24
+ return x * y
25
+
26
+
27
+ class RCAB(nn.Module):
28
+ """Residual Channel Attention Block (RCAB) used in RCAN.
29
+
30
+ Args:
31
+ num_feat (int): Channel number of intermediate features.
32
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
33
+ res_scale (float): Scale the residual. Default: 1.
34
+ """
35
+
36
+ def __init__(self, num_feat, squeeze_factor=16, res_scale=1):
37
+ super(RCAB, self).__init__()
38
+ self.res_scale = res_scale
39
+
40
+ self.rcab = nn.Sequential(
41
+ nn.Conv2d(num_feat, num_feat, 3, 1, 1), nn.ReLU(True), nn.Conv2d(num_feat, num_feat, 3, 1, 1),
42
+ ChannelAttention(num_feat, squeeze_factor))
43
+
44
+ def forward(self, x):
45
+ res = self.rcab(x) * self.res_scale
46
+ return res + x
47
+
48
+
49
+ class ResidualGroup(nn.Module):
50
+ """Residual Group of RCAB.
51
+
52
+ Args:
53
+ num_feat (int): Channel number of intermediate features.
54
+ num_block (int): Block number in the body network.
55
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
56
+ res_scale (float): Scale the residual. Default: 1.
57
+ """
58
+
59
+ def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1):
60
+ super(ResidualGroup, self).__init__()
61
+
62
+ self.residual_group = make_layer(
63
+ RCAB, num_block, num_feat=num_feat, squeeze_factor=squeeze_factor, res_scale=res_scale)
64
+ self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
65
+
66
+ def forward(self, x):
67
+ res = self.conv(self.residual_group(x))
68
+ return res + x
69
+
70
+
71
+ @ARCH_REGISTRY.register()
72
+ class RCAN(nn.Module):
73
+ """Residual Channel Attention Networks.
74
+
75
+ Paper: Image Super-Resolution Using Very Deep Residual Channel Attention
76
+ Networks
77
+ Ref git repo: https://github.com/yulunzhang/RCAN.
78
+
79
+ Args:
80
+ num_in_ch (int): Channel number of inputs.
81
+ num_out_ch (int): Channel number of outputs.
82
+ num_feat (int): Channel number of intermediate features.
83
+ Default: 64.
84
+ num_group (int): Number of ResidualGroup. Default: 10.
85
+ num_block (int): Number of RCAB in ResidualGroup. Default: 16.
86
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
87
+ upscale (int): Upsampling factor. Supports 2^n and 3.
88
+ Default: 4.
89
+ res_scale (float): Used to scale the residual in residual block.
90
+ Default: 1.
91
+ img_range (float): Image range. Default: 255.
92
+ rgb_mean (tuple[float]): Image mean in RGB orders.
93
+ Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
94
+ """
95
+
96
+ def __init__(self,
97
+ num_in_ch,
98
+ num_out_ch,
99
+ num_feat=64,
100
+ num_group=10,
101
+ num_block=16,
102
+ squeeze_factor=16,
103
+ upscale=4,
104
+ res_scale=1,
105
+ img_range=255.,
106
+ rgb_mean=(0.4488, 0.4371, 0.4040)):
107
+ super(RCAN, self).__init__()
108
+
109
+ self.img_range = img_range
110
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
111
+
112
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
113
+ self.body = make_layer(
114
+ ResidualGroup,
115
+ num_group,
116
+ num_feat=num_feat,
117
+ num_block=num_block,
118
+ squeeze_factor=squeeze_factor,
119
+ res_scale=res_scale)
120
+ self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
121
+ self.upsample = Upsample(upscale, num_feat)
122
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
123
+
124
+ def forward(self, x):
125
+ self.mean = self.mean.type_as(x)
126
+
127
+ x = (x - self.mean) * self.img_range
128
+ x = self.conv_first(x)
129
+ res = self.conv_after_body(self.body(x))
130
+ res += x
131
+
132
+ x = self.conv_last(self.upsample(res))
133
+ x = x / self.img_range + self.mean
134
+
135
+ return x
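A minimal forward-pass sketch for RCAN with randomly initialized weights (assuming the package is importable as r_basicsr; the default configuration upscales by 4):

import torch
from r_basicsr.archs.rcan_arch import RCAN

model = RCAN(num_in_ch=3, num_out_ch=3, num_feat=64, num_group=10, num_block=16, upscale=4)
lr = torch.rand(1, 3, 48, 48)       # low-resolution RGB input in [0, 1]
with torch.no_grad():
    sr = model(lr)
print(sr.shape)                     # torch.Size([1, 3, 192, 192]) for upscale=4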
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/ridnet_arch.py ADDED
@@ -0,0 +1,184 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from r_basicsr.utils.registry import ARCH_REGISTRY
5
+ from .arch_util import ResidualBlockNoBN, make_layer
6
+
7
+
8
+ class MeanShift(nn.Conv2d):
9
+ """ Data normalization with mean and std.
10
+
11
+ Args:
12
+ rgb_range (int): Maximum value of RGB.
13
+ rgb_mean (list[float]): Mean for RGB channels.
14
+ rgb_std (list[float]): Std for RGB channels.
15
+ sign (int): For subtraction, sign is -1, for addition, sign is 1.
16
+ Default: -1.
17
+ requires_grad (bool): Whether to update the self.weight and self.bias.
18
+ Default: True.
19
+ """
20
+
21
+ def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1, requires_grad=True):
22
+ super(MeanShift, self).__init__(3, 3, kernel_size=1)
23
+ std = torch.Tensor(rgb_std)
24
+ self.weight.data = torch.eye(3).view(3, 3, 1, 1)
25
+ self.weight.data.div_(std.view(3, 1, 1, 1))
26
+ self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
27
+ self.bias.data.div_(std)
28
+ self.requires_grad = requires_grad
29
+
30
+
31
+ class EResidualBlockNoBN(nn.Module):
32
+ """Enhanced Residual block without BN.
33
+
34
+ There are three convolution layers in residual branch.
35
+
36
+ It has a style of:
37
+ ---Conv-ReLU-Conv-ReLU-Conv-+-ReLU-
38
+ |__________________________|
39
+ """
40
+
41
+ def __init__(self, in_channels, out_channels):
42
+ super(EResidualBlockNoBN, self).__init__()
43
+
44
+ self.body = nn.Sequential(
45
+ nn.Conv2d(in_channels, out_channels, 3, 1, 1),
46
+ nn.ReLU(inplace=True),
47
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1),
48
+ nn.ReLU(inplace=True),
49
+ nn.Conv2d(out_channels, out_channels, 1, 1, 0),
50
+ )
51
+ self.relu = nn.ReLU(inplace=True)
52
+
53
+ def forward(self, x):
54
+ out = self.body(x)
55
+ out = self.relu(out + x)
56
+ return out
57
+
58
+
59
+ class MergeRun(nn.Module):
60
+ """ Merge-and-run unit.
61
+
62
+ This unit contains two branches with different dilated convolutions,
63
+ followed by a convolution to process the concatenated features.
64
+
65
+ Paper: Real Image Denoising with Feature Attention
66
+ Ref git repo: https://github.com/saeed-anwar/RIDNet
67
+ """
68
+
69
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
70
+ super(MergeRun, self).__init__()
71
+
72
+ self.dilation1 = nn.Sequential(
73
+ nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True),
74
+ nn.Conv2d(out_channels, out_channels, kernel_size, stride, 2, 2), nn.ReLU(inplace=True))
75
+ self.dilation2 = nn.Sequential(
76
+ nn.Conv2d(in_channels, out_channels, kernel_size, stride, 3, 3), nn.ReLU(inplace=True),
77
+ nn.Conv2d(out_channels, out_channels, kernel_size, stride, 4, 4), nn.ReLU(inplace=True))
78
+
79
+ self.aggregation = nn.Sequential(
80
+ nn.Conv2d(out_channels * 2, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True))
81
+
82
+ def forward(self, x):
83
+ dilation1 = self.dilation1(x)
84
+ dilation2 = self.dilation2(x)
85
+ out = torch.cat([dilation1, dilation2], dim=1)
86
+ out = self.aggregation(out)
87
+ out = out + x
88
+ return out
89
+
90
+
91
+ class ChannelAttention(nn.Module):
92
+ """Channel attention.
93
+
94
+ Args:
95
+ num_feat (int): Channel number of intermediate features.
96
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
97
+ """
98
+
99
+ def __init__(self, mid_channels, squeeze_factor=16):
100
+ super(ChannelAttention, self).__init__()
101
+ self.attention = nn.Sequential(
102
+ nn.AdaptiveAvgPool2d(1), nn.Conv2d(mid_channels, mid_channels // squeeze_factor, 1, padding=0),
103
+ nn.ReLU(inplace=True), nn.Conv2d(mid_channels // squeeze_factor, mid_channels, 1, padding=0), nn.Sigmoid())
104
+
105
+ def forward(self, x):
106
+ y = self.attention(x)
107
+ return x * y
108
+
109
+
110
+ class EAM(nn.Module):
111
+ """Enhancement attention modules (EAM) in RIDNet.
112
+
113
+ This module contains a merge-and-run unit, a residual block,
114
+ an enhanced residual block and a feature attention unit.
115
+
116
+ Attributes:
117
+ merge: The merge-and-run unit.
118
+ block1: The residual block.
119
+ block2: The enhanced residual block.
120
+ ca: The feature/channel attention unit.
121
+ """
122
+
123
+ def __init__(self, in_channels, mid_channels, out_channels):
124
+ super(EAM, self).__init__()
125
+
126
+ self.merge = MergeRun(in_channels, mid_channels)
127
+ self.block1 = ResidualBlockNoBN(mid_channels)
128
+ self.block2 = EResidualBlockNoBN(mid_channels, out_channels)
129
+ self.ca = ChannelAttention(out_channels)
130
+ # The residual block in the paper contains a relu after addition.
131
+ self.relu = nn.ReLU(inplace=True)
132
+
133
+ def forward(self, x):
134
+ out = self.merge(x)
135
+ out = self.relu(self.block1(out))
136
+ out = self.block2(out)
137
+ out = self.ca(out)
138
+ return out
139
+
140
+
141
+ @ARCH_REGISTRY.register()
142
+ class RIDNet(nn.Module):
143
+ """RIDNet: Real Image Denoising with Feature Attention.
144
+
145
+ Ref git repo: https://github.com/saeed-anwar/RIDNet
146
+
147
+ Args:
148
+ in_channels (int): Channel number of inputs.
149
+ mid_channels (int): Channel number of EAM modules.
150
+ Default: 64.
151
+ out_channels (int): Channel number of outputs.
152
+ num_block (int): Number of EAM. Default: 4.
153
+ img_range (float): Image range. Default: 255.
154
+ rgb_mean (tuple[float]): Image mean in RGB orders.
155
+ Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
156
+ """
157
+
158
+ def __init__(self,
159
+ in_channels,
160
+ mid_channels,
161
+ out_channels,
162
+ num_block=4,
163
+ img_range=255.,
164
+ rgb_mean=(0.4488, 0.4371, 0.4040),
165
+ rgb_std=(1.0, 1.0, 1.0)):
166
+ super(RIDNet, self).__init__()
167
+
168
+ self.sub_mean = MeanShift(img_range, rgb_mean, rgb_std)
169
+ self.add_mean = MeanShift(img_range, rgb_mean, rgb_std, 1)
170
+
171
+ self.head = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
172
+ self.body = make_layer(
173
+ EAM, num_block, in_channels=mid_channels, mid_channels=mid_channels, out_channels=mid_channels)
174
+ self.tail = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)
175
+
176
+ self.relu = nn.ReLU(inplace=True)
177
+
178
+ def forward(self, x):
179
+ res = self.sub_mean(x)
180
+ res = self.tail(self.body(self.relu(self.head(res))))
181
+ res = self.add_mean(res)
182
+
183
+ out = x + res
184
+ return out
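RIDNet is a denoiser, so input and output have the same resolution. A minimal sketch with random weights (assuming the package is importable as r_basicsr):

import torch
from r_basicsr.archs.ridnet_arch import RIDNet

model = RIDNet(in_channels=3, mid_channels=64, out_channels=3)
noisy = torch.rand(1, 3, 64, 64)
with torch.no_grad():
    denoised = model(noisy)         # predicted residual added back onto the input
print(denoised.shape)               # torch.Size([1, 3, 64, 64])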
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/rrdbnet_arch.py ADDED
@@ -0,0 +1,119 @@
1
+ import torch
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ from r_basicsr.utils.registry import ARCH_REGISTRY
6
+ from .arch_util import default_init_weights, make_layer, pixel_unshuffle
7
+
8
+
9
+ class ResidualDenseBlock(nn.Module):
10
+ """Residual Dense Block.
11
+
12
+ Used in RRDB block in ESRGAN.
13
+
14
+ Args:
15
+ num_feat (int): Channel number of intermediate features.
16
+ num_grow_ch (int): Channels for each growth.
17
+ """
18
+
19
+ def __init__(self, num_feat=64, num_grow_ch=32):
20
+ super(ResidualDenseBlock, self).__init__()
21
+ self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
22
+ self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
23
+ self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
24
+ self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
25
+ self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
26
+
27
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
28
+
29
+ # initialization
30
+ default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
31
+
32
+ def forward(self, x):
33
+ x1 = self.lrelu(self.conv1(x))
34
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
35
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
36
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
37
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
38
+ # Empirically, we use 0.2 to scale the residual for better performance
39
+ return x5 * 0.2 + x
40
+
41
+
42
+ class RRDB(nn.Module):
43
+ """Residual in Residual Dense Block.
44
+
45
+ Used in RRDB-Net in ESRGAN.
46
+
47
+ Args:
48
+ num_feat (int): Channel number of intermediate features.
49
+ num_grow_ch (int): Channels for each growth.
50
+ """
51
+
52
+ def __init__(self, num_feat, num_grow_ch=32):
53
+ super(RRDB, self).__init__()
54
+ self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
55
+ self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
56
+ self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
57
+
58
+ def forward(self, x):
59
+ out = self.rdb1(x)
60
+ out = self.rdb2(out)
61
+ out = self.rdb3(out)
62
+ # Empirically, we use 0.2 to scale the residual for better performance
63
+ return out * 0.2 + x
64
+
65
+
66
+ @ARCH_REGISTRY.register()
67
+ class RRDBNet(nn.Module):
68
+ """Networks consisting of Residual in Residual Dense Block, which is used
69
+ in ESRGAN.
70
+
71
+ ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
72
+
73
+ We extend ESRGAN for scale x2 and scale x1.
74
+ Note: This is one option for scale 1, scale 2 in RRDBNet.
75
+ We first employ pixel-unshuffle (an inverse operation of pixel shuffle) to reduce the spatial size
76
+ and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
77
+
78
+ Args:
79
+ num_in_ch (int): Channel number of inputs.
80
+ num_out_ch (int): Channel number of outputs.
81
+ num_feat (int): Channel number of intermediate features.
82
+ Default: 64
83
+ num_block (int): Block number in the trunk network. Defaults: 23
84
+ num_grow_ch (int): Channels for each growth. Default: 32.
85
+ """
86
+
87
+ def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
88
+ super(RRDBNet, self).__init__()
89
+ self.scale = scale
90
+ if scale == 2:
91
+ num_in_ch = num_in_ch * 4
92
+ elif scale == 1:
93
+ num_in_ch = num_in_ch * 16
94
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
95
+ self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
96
+ self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
97
+ # upsample
98
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
99
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
100
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
101
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
102
+
103
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
104
+
105
+ def forward(self, x):
106
+ if self.scale == 2:
107
+ feat = pixel_unshuffle(x, scale=2)
108
+ elif self.scale == 1:
109
+ feat = pixel_unshuffle(x, scale=4)
110
+ else:
111
+ feat = x
112
+ feat = self.conv_first(feat)
113
+ body_feat = self.conv_body(self.body(feat))
114
+ feat = feat + body_feat
115
+ # upsample
116
+ feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
117
+ feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
118
+ out = self.conv_last(self.lrelu(self.conv_hr(feat)))
119
+ return out
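A minimal sketch of RRDBNet at the default x4 scale (assuming the package is importable as r_basicsr; num_block is reduced here only to keep the toy example fast):

import torch
from r_basicsr.archs.rrdbnet_arch import RRDBNet

model = RRDBNet(num_in_ch=3, num_out_ch=3, scale=4, num_feat=64, num_block=6)
lr = torch.rand(1, 3, 64, 64)
with torch.no_grad():
    sr = model(lr)                  # two nearest 2x upsamples give an overall x4 output
print(sr.shape)                     # torch.Size([1, 3, 256, 256])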
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/spynet_arch.py ADDED
@@ -0,0 +1,96 @@
1
+ import math
2
+ import torch
3
+ from torch import nn as nn
4
+ from torch.nn import functional as F
5
+
6
+ from r_basicsr.utils.registry import ARCH_REGISTRY
7
+ from .arch_util import flow_warp
8
+
9
+
10
+ class BasicModule(nn.Module):
11
+ """Basic Module for SpyNet.
12
+ """
13
+
14
+ def __init__(self):
15
+ super(BasicModule, self).__init__()
16
+
17
+ self.basic_module = nn.Sequential(
18
+ nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
19
+ nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
20
+ nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
21
+ nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
22
+ nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))
23
+
24
+ def forward(self, tensor_input):
25
+ return self.basic_module(tensor_input)
26
+
27
+
28
+ @ARCH_REGISTRY.register()
29
+ class SpyNet(nn.Module):
30
+ """SpyNet architecture.
31
+
32
+ Args:
33
+ load_path (str): path for pretrained SpyNet. Default: None.
34
+ """
35
+
36
+ def __init__(self, load_path=None):
37
+ super(SpyNet, self).__init__()
38
+ self.basic_module = nn.ModuleList([BasicModule() for _ in range(6)])
39
+ if load_path:
40
+ self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
41
+
42
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
43
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
44
+
45
+ def preprocess(self, tensor_input):
46
+ tensor_output = (tensor_input - self.mean) / self.std
47
+ return tensor_output
48
+
49
+ def process(self, ref, supp):
50
+ flow = []
51
+
52
+ ref = [self.preprocess(ref)]
53
+ supp = [self.preprocess(supp)]
54
+
55
+ for level in range(5):
56
+ ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False))
57
+ supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False))
58
+
59
+ flow = ref[0].new_zeros(
60
+ [ref[0].size(0), 2,
61
+ int(math.floor(ref[0].size(2) / 2.0)),
62
+ int(math.floor(ref[0].size(3) / 2.0))])
63
+
64
+ for level in range(len(ref)):
65
+ upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
66
+
67
+ if upsampled_flow.size(2) != ref[level].size(2):
68
+ upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 0, 0, 1], mode='replicate')
69
+ if upsampled_flow.size(3) != ref[level].size(3):
70
+ upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate')
71
+
72
+ flow = self.basic_module[level](torch.cat([
73
+ ref[level],
74
+ flow_warp(
75
+ supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'),
76
+ upsampled_flow
77
+ ], 1)) + upsampled_flow
78
+
79
+ return flow
80
+
81
+ def forward(self, ref, supp):
82
+ assert ref.size() == supp.size()
83
+
84
+ h, w = ref.size(2), ref.size(3)
85
+ w_floor = math.floor(math.ceil(w / 32.0) * 32.0)
86
+ h_floor = math.floor(math.ceil(h / 32.0) * 32.0)
87
+
88
+ ref = F.interpolate(input=ref, size=(h_floor, w_floor), mode='bilinear', align_corners=False)
89
+ supp = F.interpolate(input=supp, size=(h_floor, w_floor), mode='bilinear', align_corners=False)
90
+
91
+ flow = F.interpolate(input=self.process(ref, supp), size=(h, w), mode='bilinear', align_corners=False)
92
+
93
+ flow[:, 0, :, :] *= float(w) / float(w_floor)
94
+ flow[:, 1, :, :] *= float(h) / float(h_floor)
95
+
96
+ return flow
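SpyNet estimates optical flow between two frames of the same size; internally the frames are resized to multiples of 32 and the flow is rescaled back. A minimal sketch with random weights (assuming the package is importable as r_basicsr; pass load_path to use a pretrained checkpoint):

import torch
from r_basicsr.archs.spynet_arch import SpyNet

model = SpyNet(load_path=None)            # random init; real use needs pretrained weights
ref = torch.rand(1, 3, 120, 180)
supp = torch.rand(1, 3, 120, 180)
with torch.no_grad():
    flow = model(ref, supp)               # per-pixel (dx, dy) flow field
print(flow.shape)                         # torch.Size([1, 2, 120, 180])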
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/srresnet_arch.py ADDED
@@ -0,0 +1,65 @@
1
+ from torch import nn as nn
2
+ from torch.nn import functional as F
3
+
4
+ from r_basicsr.utils.registry import ARCH_REGISTRY
5
+ from .arch_util import ResidualBlockNoBN, default_init_weights, make_layer
6
+
7
+
8
+ @ARCH_REGISTRY.register()
9
+ class MSRResNet(nn.Module):
10
+ """Modified SRResNet.
11
+
12
+ A compact version of the SRResNet introduced in
13
+ "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network"
14
+ It uses residual blocks without BN, similar to EDSR.
15
+ Currently, it supports x2, x3 and x4 upsampling scale factor.
16
+
17
+ Args:
18
+ num_in_ch (int): Channel number of inputs. Default: 3.
19
+ num_out_ch (int): Channel number of outputs. Default: 3.
20
+ num_feat (int): Channel number of intermediate features. Default: 64.
21
+ num_block (int): Block number in the body network. Default: 16.
22
+ upscale (int): Upsampling factor. Support x2, x3 and x4. Default: 4.
23
+ """
24
+
25
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4):
26
+ super(MSRResNet, self).__init__()
27
+ self.upscale = upscale
28
+
29
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
30
+ self.body = make_layer(ResidualBlockNoBN, num_block, num_feat=num_feat)
31
+
32
+ # upsampling
33
+ if self.upscale in [2, 3]:
34
+ self.upconv1 = nn.Conv2d(num_feat, num_feat * self.upscale * self.upscale, 3, 1, 1)
35
+ self.pixel_shuffle = nn.PixelShuffle(self.upscale)
36
+ elif self.upscale == 4:
37
+ self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
38
+ self.upconv2 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
39
+ self.pixel_shuffle = nn.PixelShuffle(2)
40
+
41
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
42
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
43
+
44
+ # activation function
45
+ self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
46
+
47
+ # initialization
48
+ default_init_weights([self.conv_first, self.upconv1, self.conv_hr, self.conv_last], 0.1)
49
+ if self.upscale == 4:
50
+ default_init_weights(self.upconv2, 0.1)
51
+
52
+ def forward(self, x):
53
+ feat = self.lrelu(self.conv_first(x))
54
+ out = self.body(feat)
55
+
56
+ if self.upscale == 4:
57
+ out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
58
+ out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
59
+ elif self.upscale in [2, 3]:
60
+ out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
61
+
62
+ out = self.conv_last(self.lrelu(self.conv_hr(out)))
63
+ base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
64
+ out += base
65
+ return out
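A minimal sketch for MSRResNet (assuming the package is importable as r_basicsr; the default configuration upsamples x4 and adds a bilinear-upsampled base image to the learned output):

import torch
from r_basicsr.archs.srresnet_arch import MSRResNet

model = MSRResNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4)
lr = torch.rand(1, 3, 32, 32)
with torch.no_grad():
    sr = model(lr)
print(sr.shape)                     # torch.Size([1, 3, 128, 128])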
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/srvgg_arch.py ADDED
@@ -0,0 +1,70 @@
1
+ from torch import nn as nn
2
+ from torch.nn import functional as F
3
+
4
+ from r_basicsr.utils.registry import ARCH_REGISTRY
5
+
6
+
7
+ @ARCH_REGISTRY.register(suffix='basicsr')
8
+ class SRVGGNetCompact(nn.Module):
9
+ """A compact VGG-style network structure for super-resolution.
10
+
11
+ It is a compact network structure, which performs upsampling in the last layer and no convolution is
12
+ conducted on the HR feature space.
13
+
14
+ Args:
15
+ num_in_ch (int): Channel number of inputs. Default: 3.
16
+ num_out_ch (int): Channel number of outputs. Default: 3.
17
+ num_feat (int): Channel number of intermediate features. Default: 64.
18
+ num_conv (int): Number of convolution layers in the body network. Default: 16.
19
+ upscale (int): Upsampling factor. Default: 4.
20
+ act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
21
+ """
22
+
23
+ def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'):
24
+ super(SRVGGNetCompact, self).__init__()
25
+ self.num_in_ch = num_in_ch
26
+ self.num_out_ch = num_out_ch
27
+ self.num_feat = num_feat
28
+ self.num_conv = num_conv
29
+ self.upscale = upscale
30
+ self.act_type = act_type
31
+
32
+ self.body = nn.ModuleList()
33
+ # the first conv
34
+ self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
35
+ # the first activation
36
+ if act_type == 'relu':
37
+ activation = nn.ReLU(inplace=True)
38
+ elif act_type == 'prelu':
39
+ activation = nn.PReLU(num_parameters=num_feat)
40
+ elif act_type == 'leakyrelu':
41
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
42
+ self.body.append(activation)
43
+
44
+ # the body structure
45
+ for _ in range(num_conv):
46
+ self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
47
+ # activation
48
+ if act_type == 'relu':
49
+ activation = nn.ReLU(inplace=True)
50
+ elif act_type == 'prelu':
51
+ activation = nn.PReLU(num_parameters=num_feat)
52
+ elif act_type == 'leakyrelu':
53
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
54
+ self.body.append(activation)
55
+
56
+ # the last conv
57
+ self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1))
58
+ # upsample
59
+ self.upsampler = nn.PixelShuffle(upscale)
60
+
61
+ def forward(self, x):
62
+ out = x
63
+ for i in range(0, len(self.body)):
64
+ out = self.body[i](out)
65
+
66
+ out = self.upsampler(out)
67
+ # add the nearest upsampled image, so that the network learns the residual
68
+ base = F.interpolate(x, scale_factor=self.upscale, mode='nearest')
69
+ out += base
70
+ return out
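SRVGGNetCompact keeps every convolution at the low resolution and only pixel-shuffles at the end, which is what makes it cheap. A minimal sketch (assuming the package is importable as r_basicsr):

import torch
from r_basicsr.archs.srvgg_arch import SRVGGNetCompact

model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
lr = torch.rand(1, 3, 40, 40)
with torch.no_grad():
    sr = model(lr)                  # residual output plus a nearest-upsampled base image
print(sr.shape)                     # torch.Size([1, 3, 160, 160])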
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/stylegan2_arch.py ADDED
@@ -0,0 +1,799 @@
1
+ import math
2
+ import random
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from r_basicsr.ops.fused_act import FusedLeakyReLU, fused_leaky_relu
8
+ from r_basicsr.ops.upfirdn2d import upfirdn2d
9
+ from r_basicsr.utils.registry import ARCH_REGISTRY
10
+
11
+
12
+ class NormStyleCode(nn.Module):
13
+
14
+ def forward(self, x):
15
+ """Normalize the style codes.
16
+
17
+ Args:
18
+ x (Tensor): Style codes with shape (b, c).
19
+
20
+ Returns:
21
+ Tensor: Normalized tensor.
22
+ """
23
+ return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
24
+
25
+
26
+ def make_resample_kernel(k):
27
+ """Make resampling kernel for UpFirDn.
28
+
29
+ Args:
30
+ k (list[int]): A list indicating the 1D resample kernel magnitude.
31
+
32
+ Returns:
33
+ Tensor: 2D resampled kernel.
34
+ """
35
+ k = torch.tensor(k, dtype=torch.float32)
36
+ if k.ndim == 1:
37
+ k = k[None, :] * k[:, None] # to 2D kernel, outer product
38
+ # normalize
39
+ k /= k.sum()
40
+ return k
41
+
42
+
43
+ class UpFirDnUpsample(nn.Module):
44
+ """Upsample, FIR filter, and downsample (upsampole version).
45
+
46
+ References:
47
+ 1. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.upfirdn.html # noqa: E501
48
+ 2. http://www.ece.northwestern.edu/local-apps/matlabhelp/toolbox/signal/upfirdn.html # noqa: E501
49
+
50
+ Args:
51
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
52
+ magnitude.
53
+ factor (int): Upsampling scale factor. Default: 2.
54
+ """
55
+
56
+ def __init__(self, resample_kernel, factor=2):
57
+ super(UpFirDnUpsample, self).__init__()
58
+ self.kernel = make_resample_kernel(resample_kernel) * (factor**2)
59
+ self.factor = factor
60
+
61
+ pad = self.kernel.shape[0] - factor
62
+ self.pad = ((pad + 1) // 2 + factor - 1, pad // 2)
63
+
64
+ def forward(self, x):
65
+ out = upfirdn2d(x, self.kernel.type_as(x), up=self.factor, down=1, pad=self.pad)
66
+ return out
67
+
68
+ def __repr__(self):
69
+ return (f'{self.__class__.__name__}(factor={self.factor})')
70
+
71
+
72
+ class UpFirDnDownsample(nn.Module):
73
+ """Upsample, FIR filter, and downsample (downsampole version).
74
+
75
+ Args:
76
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
77
+ magnitude.
78
+ factor (int): Downsampling scale factor. Default: 2.
79
+ """
80
+
81
+ def __init__(self, resample_kernel, factor=2):
82
+ super(UpFirDnDownsample, self).__init__()
83
+ self.kernel = make_resample_kernel(resample_kernel)
84
+ self.factor = factor
85
+
86
+ pad = self.kernel.shape[0] - factor
87
+ self.pad = ((pad + 1) // 2, pad // 2)
88
+
89
+ def forward(self, x):
90
+ out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=self.factor, pad=self.pad)
91
+ return out
92
+
93
+ def __repr__(self):
94
+ return (f'{self.__class__.__name__}(factor={self.factor})')
95
+
96
+
97
+ class UpFirDnSmooth(nn.Module):
98
+ """Upsample, FIR filter, and downsample (smooth version).
99
+
100
+ Args:
101
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
102
+ magnitude.
103
+ upsample_factor (int): Upsampling scale factor. Default: 1.
104
+ downsample_factor (int): Downsampling scale factor. Default: 1.
105
+ kernel_size (int): Kernel size: Default: 1.
106
+ """
107
+
108
+ def __init__(self, resample_kernel, upsample_factor=1, downsample_factor=1, kernel_size=1):
109
+ super(UpFirDnSmooth, self).__init__()
110
+ self.upsample_factor = upsample_factor
111
+ self.downsample_factor = downsample_factor
112
+ self.kernel = make_resample_kernel(resample_kernel)
113
+ if upsample_factor > 1:
114
+ self.kernel = self.kernel * (upsample_factor**2)
115
+
116
+ if upsample_factor > 1:
117
+ pad = (self.kernel.shape[0] - upsample_factor) - (kernel_size - 1)
118
+ self.pad = ((pad + 1) // 2 + upsample_factor - 1, pad // 2 + 1)
119
+ elif downsample_factor > 1:
120
+ pad = (self.kernel.shape[0] - downsample_factor) + (kernel_size - 1)
121
+ self.pad = ((pad + 1) // 2, pad // 2)
122
+ else:
123
+ raise NotImplementedError
124
+
125
+ def forward(self, x):
126
+ out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=1, pad=self.pad)
127
+ return out
128
+
129
+ def __repr__(self):
130
+ return (f'{self.__class__.__name__}(upsample_factor={self.upsample_factor}'
131
+ f', downsample_factor={self.downsample_factor})')
132
+
133
+
134
+ class EqualLinear(nn.Module):
135
+ """Equalized Linear as StyleGAN2.
136
+
137
+ Args:
138
+ in_channels (int): Size of each sample.
139
+ out_channels (int): Size of each output sample.
140
+ bias (bool): If set to ``False``, the layer will not learn an additive
141
+ bias. Default: ``True``.
142
+ bias_init_val (float): Bias initialized value. Default: 0.
143
+ lr_mul (float): Learning rate multiplier. Default: 1.
144
+ activation (None | str): The activation after ``linear`` operation.
145
+ Supported: 'fused_lrelu', None. Default: None.
146
+ """
147
+
148
+ def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul=1, activation=None):
149
+ super(EqualLinear, self).__init__()
150
+ self.in_channels = in_channels
151
+ self.out_channels = out_channels
152
+ self.lr_mul = lr_mul
153
+ self.activation = activation
154
+ if self.activation not in ['fused_lrelu', None]:
155
+ raise ValueError(f'Wrong activation value in EqualLinear: {activation}'
156
+ "Supported ones are: ['fused_lrelu', None].")
157
+ self.scale = (1 / math.sqrt(in_channels)) * lr_mul
158
+
159
+ self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
160
+ if bias:
161
+ self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
162
+ else:
163
+ self.register_parameter('bias', None)
164
+
165
+ def forward(self, x):
166
+ if self.bias is None:
167
+ bias = None
168
+ else:
169
+ bias = self.bias * self.lr_mul
170
+ if self.activation == 'fused_lrelu':
171
+ out = F.linear(x, self.weight * self.scale)
172
+ out = fused_leaky_relu(out, bias)
173
+ else:
174
+ out = F.linear(x, self.weight * self.scale, bias=bias)
175
+ return out
176
+
177
+ def __repr__(self):
178
+ return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
179
+ f'out_channels={self.out_channels}, bias={self.bias is not None})')
180
+
181
+
182
+ class ModulatedConv2d(nn.Module):
183
+ """Modulated Conv2d used in StyleGAN2.
184
+
185
+ There is no bias in ModulatedConv2d.
186
+
187
+ Args:
188
+ in_channels (int): Channel number of the input.
189
+ out_channels (int): Channel number of the output.
190
+ kernel_size (int): Size of the convolving kernel.
191
+ num_style_feat (int): Channel number of style features.
192
+ demodulate (bool): Whether to demodulate in the conv layer.
193
+ Default: True.
194
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
195
+ Default: None.
196
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
197
+ magnitude. Default: (1, 3, 3, 1).
198
+ eps (float): A value added to the denominator for numerical stability.
199
+ Default: 1e-8.
200
+ """
201
+
202
+ def __init__(self,
203
+ in_channels,
204
+ out_channels,
205
+ kernel_size,
206
+ num_style_feat,
207
+ demodulate=True,
208
+ sample_mode=None,
209
+ resample_kernel=(1, 3, 3, 1),
210
+ eps=1e-8):
211
+ super(ModulatedConv2d, self).__init__()
212
+ self.in_channels = in_channels
213
+ self.out_channels = out_channels
214
+ self.kernel_size = kernel_size
215
+ self.demodulate = demodulate
216
+ self.sample_mode = sample_mode
217
+ self.eps = eps
218
+
219
+ if self.sample_mode == 'upsample':
220
+ self.smooth = UpFirDnSmooth(
221
+ resample_kernel, upsample_factor=2, downsample_factor=1, kernel_size=kernel_size)
222
+ elif self.sample_mode == 'downsample':
223
+ self.smooth = UpFirDnSmooth(
224
+ resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size)
225
+ elif self.sample_mode is None:
226
+ pass
227
+ else:
228
+ raise ValueError(f'Wrong sample mode {self.sample_mode}, '
229
+ "supported ones are ['upsample', 'downsample', None].")
230
+
231
+ self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
232
+ # modulation inside each modulated conv
233
+ self.modulation = EqualLinear(
234
+ num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)
235
+
236
+ self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
237
+ self.padding = kernel_size // 2
238
+
239
+ def forward(self, x, style):
240
+ """Forward function.
241
+
242
+ Args:
243
+ x (Tensor): Tensor with shape (b, c, h, w).
244
+ style (Tensor): Tensor with shape (b, num_style_feat).
245
+
246
+ Returns:
247
+ Tensor: Modulated tensor after convolution.
248
+ """
249
+ b, c, h, w = x.shape # c = c_in
250
+ # weight modulation
251
+ style = self.modulation(style).view(b, 1, c, 1, 1)
252
+ # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
253
+ weight = self.scale * self.weight * style # (b, c_out, c_in, k, k)
254
+
255
+ if self.demodulate:
256
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
257
+ weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
258
+
259
+ weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size)
260
+
261
+ if self.sample_mode == 'upsample':
262
+ x = x.view(1, b * c, h, w)
263
+ weight = weight.view(b, self.out_channels, c, self.kernel_size, self.kernel_size)
264
+ weight = weight.transpose(1, 2).reshape(b * c, self.out_channels, self.kernel_size, self.kernel_size)
265
+ out = F.conv_transpose2d(x, weight, padding=0, stride=2, groups=b)
266
+ out = out.view(b, self.out_channels, *out.shape[2:4])
267
+ out = self.smooth(out)
268
+ elif self.sample_mode == 'downsample':
269
+ x = self.smooth(x)
270
+ x = x.view(1, b * c, *x.shape[2:4])
271
+ out = F.conv2d(x, weight, padding=0, stride=2, groups=b)
272
+ out = out.view(b, self.out_channels, *out.shape[2:4])
273
+ else:
274
+ x = x.view(1, b * c, h, w)
275
+ # weight: (b*c_out, c_in, k, k), groups=b
276
+ out = F.conv2d(x, weight, padding=self.padding, groups=b)
277
+ out = out.view(b, self.out_channels, *out.shape[2:4])
278
+
279
+ return out
280
+
281
+ def __repr__(self):
282
+ return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
283
+ f'out_channels={self.out_channels}, '
284
+ f'kernel_size={self.kernel_size}, '
285
+ f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
286
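For readers unfamiliar with the grouped-convolution trick in ModulatedConv2d.forward above, the following self-contained sketch (plain PyTorch, with illustrative tensor sizes chosen for the example rather than taken from the file) checks that folding the batch into the channel dimension with groups=b matches a straightforward per-sample modulated convolution:

import torch
import torch.nn.functional as F

# Illustrative sizes, not from the vendored file
b, c_in, c_out, k, h, w = 2, 4, 6, 3, 8, 8
eps = 1e-8

x = torch.randn(b, c_in, h, w)
weight = torch.randn(1, c_out, c_in, k, k)   # shared conv weight
style = torch.randn(b, 1, c_in, 1, 1)        # per-sample modulation

# modulate + demodulate, mirroring ModulatedConv2d.forward
w_mod = weight * style                                      # (b, c_out, c_in, k, k)
demod = torch.rsqrt(w_mod.pow(2).sum([2, 3, 4]) + eps)      # (b, c_out)
w_mod = w_mod * demod.view(b, c_out, 1, 1, 1)

# batched path: fold the batch into channels and convolve with groups=b
out_grouped = F.conv2d(
    x.view(1, b * c_in, h, w),
    w_mod.view(b * c_out, c_in, k, k),
    padding=k // 2,
    groups=b,
).view(b, c_out, h, w)

# reference path: one ordinary convolution per sample
out_loop = torch.stack(
    [F.conv2d(x[i:i + 1], w_mod[i], padding=k // 2)[0] for i in range(b)]
)

assert torch.allclose(out_grouped, out_loop, atol=1e-5)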
+
287
+
288
+ class StyleConv(nn.Module):
289
+ """Style conv.
290
+
291
+ Args:
292
+ in_channels (int): Channel number of the input.
293
+ out_channels (int): Channel number of the output.
294
+ kernel_size (int): Size of the convolving kernel.
295
+ num_style_feat (int): Channel number of style features.
296
+ demodulate (bool): Whether demodulate in the conv layer. Default: True.
297
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
298
+ Default: None.
299
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
300
+ magnitude. Default: (1, 3, 3, 1).
301
+ """
302
+
303
+ def __init__(self,
304
+ in_channels,
305
+ out_channels,
306
+ kernel_size,
307
+ num_style_feat,
308
+ demodulate=True,
309
+ sample_mode=None,
310
+ resample_kernel=(1, 3, 3, 1)):
311
+ super(StyleConv, self).__init__()
312
+ self.modulated_conv = ModulatedConv2d(
313
+ in_channels,
314
+ out_channels,
315
+ kernel_size,
316
+ num_style_feat,
317
+ demodulate=demodulate,
318
+ sample_mode=sample_mode,
319
+ resample_kernel=resample_kernel)
320
+ self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
321
+ self.activate = FusedLeakyReLU(out_channels)
322
+
323
+ def forward(self, x, style, noise=None):
324
+ # modulate
325
+ out = self.modulated_conv(x, style)
326
+ # noise injection
327
+ if noise is None:
328
+ b, _, h, w = out.shape
329
+ noise = out.new_empty(b, 1, h, w).normal_()
330
+ out = out + self.weight * noise
331
+ # activation (with bias)
332
+ out = self.activate(out)
333
+ return out
334
+
335
+
336
+ class ToRGB(nn.Module):
337
+ """To RGB from features.
338
+
339
+ Args:
340
+ in_channels (int): Channel number of input.
341
+ num_style_feat (int): Channel number of style features.
342
+ upsample (bool): Whether to upsample. Default: True.
343
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
344
+ magnitude. Default: (1, 3, 3, 1).
345
+ """
346
+
347
+ def __init__(self, in_channels, num_style_feat, upsample=True, resample_kernel=(1, 3, 3, 1)):
348
+ super(ToRGB, self).__init__()
349
+ if upsample:
350
+ self.upsample = UpFirDnUpsample(resample_kernel, factor=2)
351
+ else:
352
+ self.upsample = None
353
+ self.modulated_conv = ModulatedConv2d(
354
+ in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
355
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
356
+
357
+ def forward(self, x, style, skip=None):
358
+ """Forward function.
359
+
360
+ Args:
361
+ x (Tensor): Feature tensor with shape (b, c, h, w).
362
+ style (Tensor): Tensor with shape (b, num_style_feat).
363
+ skip (Tensor): Base/skip tensor. Default: None.
364
+
365
+ Returns:
366
+ Tensor: RGB images.
367
+ """
368
+ out = self.modulated_conv(x, style)
369
+ out = out + self.bias
370
+ if skip is not None:
371
+ if self.upsample:
372
+ skip = self.upsample(skip)
373
+ out = out + skip
374
+ return out
375
+
376
+
377
+ class ConstantInput(nn.Module):
378
+ """Constant input.
379
+
380
+ Args:
381
+ num_channel (int): Channel number of constant input.
382
+ size (int): Spatial size of constant input.
383
+ """
384
+
385
+ def __init__(self, num_channel, size):
386
+ super(ConstantInput, self).__init__()
387
+ self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
388
+
389
+ def forward(self, batch):
390
+ out = self.weight.repeat(batch, 1, 1, 1)
391
+ return out
392
+
393
+
394
+ @ARCH_REGISTRY.register()
395
+ class StyleGAN2Generator(nn.Module):
396
+ """StyleGAN2 Generator.
397
+
398
+ Args:
399
+ out_size (int): The spatial size of outputs.
400
+ num_style_feat (int): Channel number of style features. Default: 512.
401
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
402
+ channel_multiplier (int): Channel multiplier for large networks of
403
+ StyleGAN2. Default: 2.
404
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
405
+ magnitude. A cross product will be applied to extend the 1D resample
406
+ kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
407
+ lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
408
+ narrow (float): Narrow ratio for channels. Default: 1.0.
409
+ """
410
+
411
+ def __init__(self,
412
+ out_size,
413
+ num_style_feat=512,
414
+ num_mlp=8,
415
+ channel_multiplier=2,
416
+ resample_kernel=(1, 3, 3, 1),
417
+ lr_mlp=0.01,
418
+ narrow=1):
419
+ super(StyleGAN2Generator, self).__init__()
420
+ # Style MLP layers
421
+ self.num_style_feat = num_style_feat
422
+ style_mlp_layers = [NormStyleCode()]
423
+ for i in range(num_mlp):
424
+ style_mlp_layers.append(
425
+ EqualLinear(
426
+ num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
427
+ activation='fused_lrelu'))
428
+ self.style_mlp = nn.Sequential(*style_mlp_layers)
429
+
430
+ channels = {
431
+ '4': int(512 * narrow),
432
+ '8': int(512 * narrow),
433
+ '16': int(512 * narrow),
434
+ '32': int(512 * narrow),
435
+ '64': int(256 * channel_multiplier * narrow),
436
+ '128': int(128 * channel_multiplier * narrow),
437
+ '256': int(64 * channel_multiplier * narrow),
438
+ '512': int(32 * channel_multiplier * narrow),
439
+ '1024': int(16 * channel_multiplier * narrow)
440
+ }
441
+ self.channels = channels
442
+
443
+ self.constant_input = ConstantInput(channels['4'], size=4)
444
+ self.style_conv1 = StyleConv(
445
+ channels['4'],
446
+ channels['4'],
447
+ kernel_size=3,
448
+ num_style_feat=num_style_feat,
449
+ demodulate=True,
450
+ sample_mode=None,
451
+ resample_kernel=resample_kernel)
452
+ self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, resample_kernel=resample_kernel)
453
+
454
+ self.log_size = int(math.log(out_size, 2))
455
+ self.num_layers = (self.log_size - 2) * 2 + 1
456
+ self.num_latent = self.log_size * 2 - 2
457
+
458
+ self.style_convs = nn.ModuleList()
459
+ self.to_rgbs = nn.ModuleList()
460
+ self.noises = nn.Module()
461
+
462
+ in_channels = channels['4']
463
+ # noise
464
+ for layer_idx in range(self.num_layers):
465
+ resolution = 2**((layer_idx + 5) // 2)
466
+ shape = [1, 1, resolution, resolution]
467
+ self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
468
+ # style convs and to_rgbs
469
+ for i in range(3, self.log_size + 1):
470
+ out_channels = channels[f'{2**i}']
471
+ self.style_convs.append(
472
+ StyleConv(
473
+ in_channels,
474
+ out_channels,
475
+ kernel_size=3,
476
+ num_style_feat=num_style_feat,
477
+ demodulate=True,
478
+ sample_mode='upsample',
479
+ resample_kernel=resample_kernel,
480
+ ))
481
+ self.style_convs.append(
482
+ StyleConv(
483
+ out_channels,
484
+ out_channels,
485
+ kernel_size=3,
486
+ num_style_feat=num_style_feat,
487
+ demodulate=True,
488
+ sample_mode=None,
489
+ resample_kernel=resample_kernel))
490
+ self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True, resample_kernel=resample_kernel))
491
+ in_channels = out_channels
492
+
493
+ def make_noise(self):
494
+ """Make noise for noise injection."""
495
+ device = self.constant_input.weight.device
496
+ noises = [torch.randn(1, 1, 4, 4, device=device)]
497
+
498
+ for i in range(3, self.log_size + 1):
499
+ for _ in range(2):
500
+ noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
501
+
502
+ return noises
503
+
504
+ def get_latent(self, x):
505
+ return self.style_mlp(x)
506
+
507
+ def mean_latent(self, num_latent):
508
+ latent_in = torch.randn(num_latent, self.num_style_feat, device=self.constant_input.weight.device)
509
+ latent = self.style_mlp(latent_in).mean(0, keepdim=True)
510
+ return latent
511
+
512
+ def forward(self,
513
+ styles,
514
+ input_is_latent=False,
515
+ noise=None,
516
+ randomize_noise=True,
517
+ truncation=1,
518
+ truncation_latent=None,
519
+ inject_index=None,
520
+ return_latents=False):
521
+ """Forward function for StyleGAN2Generator.
522
+
523
+ Args:
524
+ styles (list[Tensor]): Sample codes of styles.
525
+ input_is_latent (bool): Whether input is latent style.
526
+ Default: False.
527
+ noise (Tensor | None): Input noise or None. Default: None.
528
+ randomize_noise (bool): Randomize noise, used when 'noise' is
529
+ None. Default: True.
530
+ truncation (float): Truncation ratio for the truncation trick. Default: 1.
531
+ truncation_latent (Tensor | None): Mean latent used as the truncation center. Default: None.
532
+ inject_index (int | None): The injection index for mixing noise.
533
+ Default: None.
534
+ return_latents (bool): Whether to return style latents.
535
+ Default: False.
536
+ """
537
+ # style codes -> latents with Style MLP layer
538
+ if not input_is_latent:
539
+ styles = [self.style_mlp(s) for s in styles]
540
+ # noises
541
+ if noise is None:
542
+ if randomize_noise:
543
+ noise = [None] * self.num_layers # for each style conv layer
544
+ else: # use the stored noise
545
+ noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)]
546
+ # style truncation
547
+ if truncation < 1:
548
+ style_truncation = []
549
+ for style in styles:
550
+ style_truncation.append(truncation_latent + truncation * (style - truncation_latent))
551
+ styles = style_truncation
552
+ # get style latent with injection
553
+ if len(styles) == 1:
554
+ inject_index = self.num_latent
555
+
556
+ if styles[0].ndim < 3:
557
+ # repeat latent code for all the layers
558
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
559
+ else: # used for encoder with different latent code for each layer
560
+ latent = styles[0]
561
+ elif len(styles) == 2: # mixing noises
562
+ if inject_index is None:
563
+ inject_index = random.randint(1, self.num_latent - 1)
564
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
565
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
566
+ latent = torch.cat([latent1, latent2], 1)
567
+
568
+ # main generation
569
+ out = self.constant_input(latent.shape[0])
570
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
571
+ skip = self.to_rgb1(out, latent[:, 1])
572
+
573
+ i = 1
574
+ for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
575
+ noise[2::2], self.to_rgbs):
576
+ out = conv1(out, latent[:, i], noise=noise1)
577
+ out = conv2(out, latent[:, i + 1], noise=noise2)
578
+ skip = to_rgb(out, latent[:, i + 2], skip)
579
+ i += 2
580
+
581
+ image = skip
582
+
583
+ if return_latents:
584
+ return image, latent
585
+ else:
586
+ return image, None
587
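As a usage sketch for the generator defined above: this is hypothetical, and it assumes that the upfirdn2d and fused_act ops in this package resolve to their pure-PyTorch fallbacks when the compiled CUDA extensions are unavailable, and that the import path matches this repository layout.

import torch
from r_basicsr.archs.stylegan2_arch import StyleGAN2Generator

# Assumption: pure-PyTorch fallbacks for the custom ops are available.
g = StyleGAN2Generator(out_size=256, num_style_feat=512, num_mlp=8)
z = torch.randn(4, 512)                      # one z code per sample
img, _ = g([z], input_is_latent=False, randomize_noise=True)
print(img.shape)                             # expected: torch.Size([4, 3, 256, 256])

# truncation trick: pull latents toward the mean latent
w_mean = g.mean_latent(4096)
img_trunc, _ = g([z], truncation=0.7, truncation_latent=w_mean)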
+
588
+
589
+ class ScaledLeakyReLU(nn.Module):
590
+ """Scaled LeakyReLU.
591
+
592
+ Args:
593
+ negative_slope (float): Negative slope. Default: 0.2.
594
+ """
595
+
596
+ def __init__(self, negative_slope=0.2):
597
+ super(ScaledLeakyReLU, self).__init__()
598
+ self.negative_slope = negative_slope
599
+
600
+ def forward(self, x):
601
+ out = F.leaky_relu(x, negative_slope=self.negative_slope)
602
+ return out * math.sqrt(2)
603
+
604
+
605
+ class EqualConv2d(nn.Module):
606
+ """Equalized Conv2d, as used in StyleGAN2.
607
+
608
+ Args:
609
+ in_channels (int): Channel number of the input.
610
+ out_channels (int): Channel number of the output.
611
+ kernel_size (int): Size of the convolving kernel.
612
+ stride (int): Stride of the convolution. Default: 1
613
+ padding (int): Zero-padding added to both sides of the input.
614
+ Default: 0.
615
+ bias (bool): If ``True``, adds a learnable bias to the output.
616
+ Default: ``True``.
617
+ bias_init_val (float): Bias initialized value. Default: 0.
618
+ """
619
+
620
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True, bias_init_val=0):
621
+ super(EqualConv2d, self).__init__()
622
+ self.in_channels = in_channels
623
+ self.out_channels = out_channels
624
+ self.kernel_size = kernel_size
625
+ self.stride = stride
626
+ self.padding = padding
627
+ self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
628
+
629
+ self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
630
+ if bias:
631
+ self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
632
+ else:
633
+ self.register_parameter('bias', None)
634
+
635
+ def forward(self, x):
636
+ out = F.conv2d(
637
+ x,
638
+ self.weight * self.scale,
639
+ bias=self.bias,
640
+ stride=self.stride,
641
+ padding=self.padding,
642
+ )
643
+
644
+ return out
645
+
646
+ def __repr__(self):
647
+ return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
648
+ f'out_channels={self.out_channels}, '
649
+ f'kernel_size={self.kernel_size},'
650
+ f' stride={self.stride}, padding={self.padding}, '
651
+ f'bias={self.bias is not None})')
652
+
653
+
654
+ class ConvLayer(nn.Sequential):
655
+ """Conv Layer used in StyleGAN2 Discriminator.
656
+
657
+ Args:
658
+ in_channels (int): Channel number of the input.
659
+ out_channels (int): Channel number of the output.
660
+ kernel_size (int): Kernel size.
661
+ downsample (bool): Whether downsample by a factor of 2.
662
+ Default: False.
663
+ resample_kernel (list[int]): A list indicating the 1D resample
664
+ kernel magnitude. A cross product will be applied to
665
+ extend the 1D resample kernel to a 2D resample kernel.
666
+ Default: (1, 3, 3, 1).
667
+ bias (bool): Whether to use bias. Default: True.
668
+ activate (bool): Whether to use activation. Default: True.
669
+ """
670
+
671
+ def __init__(self,
672
+ in_channels,
673
+ out_channels,
674
+ kernel_size,
675
+ downsample=False,
676
+ resample_kernel=(1, 3, 3, 1),
677
+ bias=True,
678
+ activate=True):
679
+ layers = []
680
+ # downsample
681
+ if downsample:
682
+ layers.append(
683
+ UpFirDnSmooth(resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size))
684
+ stride = 2
685
+ self.padding = 0
686
+ else:
687
+ stride = 1
688
+ self.padding = kernel_size // 2
689
+ # conv
690
+ layers.append(
691
+ EqualConv2d(
692
+ in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias
693
+ and not activate))
694
+ # activation
695
+ if activate:
696
+ if bias:
697
+ layers.append(FusedLeakyReLU(out_channels))
698
+ else:
699
+ layers.append(ScaledLeakyReLU(0.2))
700
+
701
+ super(ConvLayer, self).__init__(*layers)
702
+
703
+
704
+ class ResBlock(nn.Module):
705
+ """Residual block used in StyleGAN2 Discriminator.
706
+
707
+ Args:
708
+ in_channels (int): Channel number of the input.
709
+ out_channels (int): Channel number of the output.
710
+ resample_kernel (list[int]): A list indicating the 1D resample
711
+ kernel magnitude. A cross product will be applied to
712
+ extend the 1D resample kernel to a 2D resample kernel.
713
+ Default: (1, 3, 3, 1).
714
+ """
715
+
716
+ def __init__(self, in_channels, out_channels, resample_kernel=(1, 3, 3, 1)):
717
+ super(ResBlock, self).__init__()
718
+
719
+ self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
720
+ self.conv2 = ConvLayer(
721
+ in_channels, out_channels, 3, downsample=True, resample_kernel=resample_kernel, bias=True, activate=True)
722
+ self.skip = ConvLayer(
723
+ in_channels, out_channels, 1, downsample=True, resample_kernel=resample_kernel, bias=False, activate=False)
724
+
725
+ def forward(self, x):
726
+ out = self.conv1(x)
727
+ out = self.conv2(out)
728
+ skip = self.skip(x)
729
+ out = (out + skip) / math.sqrt(2)
730
+ return out
731
+
732
+
733
+ @ARCH_REGISTRY.register()
734
+ class StyleGAN2Discriminator(nn.Module):
735
+ """StyleGAN2 Discriminator.
736
+
737
+ Args:
738
+ out_size (int): The spatial size of outputs.
739
+ channel_multiplier (int): Channel multiplier for large networks of
740
+ StyleGAN2. Default: 2.
741
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
742
+ magnitude. A cross product will be applied to extend the 1D resample
743
+ kernel to a 2D resample kernel. Default: (1, 3, 3, 1).
744
+ stddev_group (int): For group stddev statistics. Default: 4.
745
+ narrow (float): Narrow ratio for channels. Default: 1.0.
746
+ """
747
+
748
+ def __init__(self, out_size, channel_multiplier=2, resample_kernel=(1, 3, 3, 1), stddev_group=4, narrow=1):
749
+ super(StyleGAN2Discriminator, self).__init__()
750
+
751
+ channels = {
752
+ '4': int(512 * narrow),
753
+ '8': int(512 * narrow),
754
+ '16': int(512 * narrow),
755
+ '32': int(512 * narrow),
756
+ '64': int(256 * channel_multiplier * narrow),
757
+ '128': int(128 * channel_multiplier * narrow),
758
+ '256': int(64 * channel_multiplier * narrow),
759
+ '512': int(32 * channel_multiplier * narrow),
760
+ '1024': int(16 * channel_multiplier * narrow)
761
+ }
762
+
763
+ log_size = int(math.log(out_size, 2))
764
+
765
+ conv_body = [ConvLayer(3, channels[f'{out_size}'], 1, bias=True, activate=True)]
766
+
767
+ in_channels = channels[f'{out_size}']
768
+ for i in range(log_size, 2, -1):
769
+ out_channels = channels[f'{2**(i - 1)}']
770
+ conv_body.append(ResBlock(in_channels, out_channels, resample_kernel))
771
+ in_channels = out_channels
772
+ self.conv_body = nn.Sequential(*conv_body)
773
+
774
+ self.final_conv = ConvLayer(in_channels + 1, channels['4'], 3, bias=True, activate=True)
775
+ self.final_linear = nn.Sequential(
776
+ EqualLinear(
777
+ channels['4'] * 4 * 4, channels['4'], bias=True, bias_init_val=0, lr_mul=1, activation='fused_lrelu'),
778
+ EqualLinear(channels['4'], 1, bias=True, bias_init_val=0, lr_mul=1, activation=None),
779
+ )
780
+ self.stddev_group = stddev_group
781
+ self.stddev_feat = 1
782
+
783
+ def forward(self, x):
784
+ out = self.conv_body(x)
785
+
786
+ b, c, h, w = out.shape
787
+ # concatenate a group stddev statistics to out
788
+ group = min(b, self.stddev_group) # Minibatch must be divisible by (or smaller than) group_size
789
+ stddev = out.view(group, -1, self.stddev_feat, c // self.stddev_feat, h, w)
790
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
791
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
792
+ stddev = stddev.repeat(group, 1, h, w)
793
+ out = torch.cat([out, stddev], 1)
794
+
795
+ out = self.final_conv(out)
796
+ out = out.view(b, -1)
797
+ out = self.final_linear(out)
798
+
799
+ return out
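The minibatch standard-deviation feature concatenated in StyleGAN2Discriminator.forward can be hard to read from the reshapes alone; here is a self-contained, illustrative reproduction of just that statistic (the sizes are made up for the example):

import torch

b, c, h, w = 8, 16, 4, 4
stddev_group, stddev_feat = 4, 1
out = torch.randn(b, c, h, w)

group = min(b, stddev_group)
stddev = out.view(group, -1, stddev_feat, c // stddev_feat, h, w)
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)   # std over the group dim
stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)    # one scalar per sub-batch
stddev = stddev.repeat(group, 1, h, w)                      # broadcast back to every sample
out = torch.cat([out, stddev], 1)                           # one extra constant feature map
print(out.shape)                                            # torch.Size([8, 17, 4, 4])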
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/swinir_arch.py ADDED
@@ -0,0 +1,956 @@
1
+ # Modified from https://github.com/JingyunLiang/SwinIR
2
+ # SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
3
+ # Originally Written by Ze Liu, Modified by Jingyun Liang.
4
+
5
+ import math
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.utils.checkpoint as checkpoint
9
+
10
+ from r_basicsr.utils.registry import ARCH_REGISTRY
11
+ from .arch_util import to_2tuple, trunc_normal_
12
+
13
+
14
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
15
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
16
+
17
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
18
+ """
19
+ if drop_prob == 0. or not training:
20
+ return x
21
+ keep_prob = 1 - drop_prob
22
+ shape = (x.shape[0], ) + (1, ) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
23
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
24
+ random_tensor.floor_() # binarize
25
+ output = x.div(keep_prob) * random_tensor
26
+ return output
27
+
28
+
29
+ class DropPath(nn.Module):
30
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
31
+
32
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
33
+ """
34
+
35
+ def __init__(self, drop_prob=None):
36
+ super(DropPath, self).__init__()
37
+ self.drop_prob = drop_prob
38
+
39
+ def forward(self, x):
40
+ return drop_path(x, self.drop_prob, self.training)
41
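A quick numerical sanity check of the stochastic-depth scaling implemented by drop_path above: because kept samples are divided by keep_prob, the expectation of the output matches the input (the values below are random, so the match is only approximate):

import torch

torch.manual_seed(0)
x = torch.ones(10000, 8)
drop_prob, keep_prob = 0.3, 0.7
mask = (keep_prob + torch.rand(x.shape[0], 1)).floor()   # 1 with probability keep_prob
y = x.div(keep_prob) * mask
print(y.mean().item())   # close to 1.0, i.e. E[y] == x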
+
42
+
43
+ class Mlp(nn.Module):
44
+
45
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
46
+ super().__init__()
47
+ out_features = out_features or in_features
48
+ hidden_features = hidden_features or in_features
49
+ self.fc1 = nn.Linear(in_features, hidden_features)
50
+ self.act = act_layer()
51
+ self.fc2 = nn.Linear(hidden_features, out_features)
52
+ self.drop = nn.Dropout(drop)
53
+
54
+ def forward(self, x):
55
+ x = self.fc1(x)
56
+ x = self.act(x)
57
+ x = self.drop(x)
58
+ x = self.fc2(x)
59
+ x = self.drop(x)
60
+ return x
61
+
62
+
63
+ def window_partition(x, window_size):
64
+ """
65
+ Args:
66
+ x: (b, h, w, c)
67
+ window_size (int): window size
68
+
69
+ Returns:
70
+ windows: (num_windows*b, window_size, window_size, c)
71
+ """
72
+ b, h, w, c = x.shape
73
+ x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
74
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
75
+ return windows
76
+
77
+
78
+ def window_reverse(windows, window_size, h, w):
79
+ """
80
+ Args:
81
+ windows: (num_windows*b, window_size, window_size, c)
82
+ window_size (int): Window size
83
+ h (int): Height of image
84
+ w (int): Width of image
85
+
86
+ Returns:
87
+ x: (b, h, w, c)
88
+ """
89
+ b = int(windows.shape[0] / (h * w / window_size / window_size))
90
+ x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
91
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
92
+ return x
93
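A minimal round-trip check for window_partition / window_reverse, assuming the package import path below; h and w must be multiples of the window size, as the reshapes require:

import torch
from r_basicsr.archs.swinir_arch import window_partition, window_reverse

b, h, w, c, window_size = 2, 16, 16, 32, 8
x = torch.randn(b, h, w, c)

windows = window_partition(x, window_size)           # (b * num_windows, 8, 8, c)
assert windows.shape == (b * (h // window_size) * (w // window_size),
                         window_size, window_size, c)

x_back = window_reverse(windows, window_size, h, w)  # back to (b, h, w, c)
assert torch.equal(x, x_back)                        # pure data movement, exact match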
+
94
+
95
+ class WindowAttention(nn.Module):
96
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
97
+ It supports both of shifted and non-shifted window.
98
+
99
+ Args:
100
+ dim (int): Number of input channels.
101
+ window_size (tuple[int]): The height and width of the window.
102
+ num_heads (int): Number of attention heads.
103
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
104
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
105
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
106
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
107
+ """
108
+
109
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
110
+
111
+ super().__init__()
112
+ self.dim = dim
113
+ self.window_size = window_size # Wh, Ww
114
+ self.num_heads = num_heads
115
+ head_dim = dim // num_heads
116
+ self.scale = qk_scale or head_dim**-0.5
117
+
118
+ # define a parameter table of relative position bias
119
+ self.relative_position_bias_table = nn.Parameter(
120
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
121
+
122
+ # get pair-wise relative position index for each token inside the window
123
+ coords_h = torch.arange(self.window_size[0])
124
+ coords_w = torch.arange(self.window_size[1])
125
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
126
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
127
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
128
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
129
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
130
+ relative_coords[:, :, 1] += self.window_size[1] - 1
131
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
132
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
133
+ self.register_buffer('relative_position_index', relative_position_index)
134
+
135
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
136
+ self.attn_drop = nn.Dropout(attn_drop)
137
+ self.proj = nn.Linear(dim, dim)
138
+
139
+ self.proj_drop = nn.Dropout(proj_drop)
140
+
141
+ trunc_normal_(self.relative_position_bias_table, std=.02)
142
+ self.softmax = nn.Softmax(dim=-1)
143
+
144
+ def forward(self, x, mask=None):
145
+ """
146
+ Args:
147
+ x: input features with shape of (num_windows*b, n, c)
148
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
149
+ """
150
+ b_, n, c = x.shape
151
+ qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
152
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
153
+
154
+ q = q * self.scale
155
+ attn = (q @ k.transpose(-2, -1))
156
+
157
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
158
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
159
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
160
+ attn = attn + relative_position_bias.unsqueeze(0)
161
+
162
+ if mask is not None:
163
+ nw = mask.shape[0]
164
+ attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
165
+ attn = attn.view(-1, self.num_heads, n, n)
166
+ attn = self.softmax(attn)
167
+ else:
168
+ attn = self.softmax(attn)
169
+
170
+ attn = self.attn_drop(attn)
171
+
172
+ x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
173
+ x = self.proj(x)
174
+ x = self.proj_drop(x)
175
+ return x
176
+
177
+ def extra_repr(self) -> str:
178
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
179
+
180
+ def flops(self, n):
181
+ # calculate flops for 1 window with token length of n
182
+ flops = 0
183
+ # qkv = self.qkv(x)
184
+ flops += n * self.dim * 3 * self.dim
185
+ # attn = (q @ k.transpose(-2, -1))
186
+ flops += self.num_heads * n * (self.dim // self.num_heads) * n
187
+ # x = (attn @ v)
188
+ flops += self.num_heads * n * n * (self.dim // self.num_heads)
189
+ # x = self.proj(x)
190
+ flops += n * self.dim * self.dim
191
+ return flops
192
+
193
+
194
+ class SwinTransformerBlock(nn.Module):
195
+ r""" Swin Transformer Block.
196
+
197
+ Args:
198
+ dim (int): Number of input channels.
199
+ input_resolution (tuple[int]): Input resolution.
200
+ num_heads (int): Number of attention heads.
201
+ window_size (int): Window size.
202
+ shift_size (int): Shift size for SW-MSA.
203
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
204
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
205
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
206
+ drop (float, optional): Dropout rate. Default: 0.0
207
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
208
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
209
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
210
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
211
+ """
212
+
213
+ def __init__(self,
214
+ dim,
215
+ input_resolution,
216
+ num_heads,
217
+ window_size=7,
218
+ shift_size=0,
219
+ mlp_ratio=4.,
220
+ qkv_bias=True,
221
+ qk_scale=None,
222
+ drop=0.,
223
+ attn_drop=0.,
224
+ drop_path=0.,
225
+ act_layer=nn.GELU,
226
+ norm_layer=nn.LayerNorm):
227
+ super().__init__()
228
+ self.dim = dim
229
+ self.input_resolution = input_resolution
230
+ self.num_heads = num_heads
231
+ self.window_size = window_size
232
+ self.shift_size = shift_size
233
+ self.mlp_ratio = mlp_ratio
234
+ if min(self.input_resolution) <= self.window_size:
235
+ # if window size is larger than input resolution, we don't partition windows
236
+ self.shift_size = 0
237
+ self.window_size = min(self.input_resolution)
238
+ assert 0 <= self.shift_size < self.window_size, 'shift_size must be in [0, window_size)'
239
+
240
+ self.norm1 = norm_layer(dim)
241
+ self.attn = WindowAttention(
242
+ dim,
243
+ window_size=to_2tuple(self.window_size),
244
+ num_heads=num_heads,
245
+ qkv_bias=qkv_bias,
246
+ qk_scale=qk_scale,
247
+ attn_drop=attn_drop,
248
+ proj_drop=drop)
249
+
250
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
251
+ self.norm2 = norm_layer(dim)
252
+ mlp_hidden_dim = int(dim * mlp_ratio)
253
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
254
+
255
+ if self.shift_size > 0:
256
+ attn_mask = self.calculate_mask(self.input_resolution)
257
+ else:
258
+ attn_mask = None
259
+
260
+ self.register_buffer('attn_mask', attn_mask)
261
+
262
+ def calculate_mask(self, x_size):
263
+ # calculate attention mask for SW-MSA
264
+ h, w = x_size
265
+ img_mask = torch.zeros((1, h, w, 1)) # 1 h w 1
266
+ h_slices = (slice(0, -self.window_size), slice(-self.window_size,
267
+ -self.shift_size), slice(-self.shift_size, None))
268
+ w_slices = (slice(0, -self.window_size), slice(-self.window_size,
269
+ -self.shift_size), slice(-self.shift_size, None))
270
+ cnt = 0
271
+ for h in h_slices:
272
+ for w in w_slices:
273
+ img_mask[:, h, w, :] = cnt
274
+ cnt += 1
275
+
276
+ mask_windows = window_partition(img_mask, self.window_size) # nw, window_size, window_size, 1
277
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
278
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
279
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
280
+
281
+ return attn_mask
282
+
283
+ def forward(self, x, x_size):
284
+ h, w = x_size
285
+ b, _, c = x.shape
286
+ # assert seq_len == h * w, "input feature has wrong size"
287
+
288
+ shortcut = x
289
+ x = self.norm1(x)
290
+ x = x.view(b, h, w, c)
291
+
292
+ # cyclic shift
293
+ if self.shift_size > 0:
294
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
295
+ else:
296
+ shifted_x = x
297
+
298
+ # partition windows
299
+ x_windows = window_partition(shifted_x, self.window_size) # nw*b, window_size, window_size, c
300
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c
301
+
302
+ # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
303
+ if self.input_resolution == x_size:
304
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nw*b, window_size*window_size, c
305
+ else:
306
+ attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
307
+
308
+ # merge windows
309
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
310
+ shifted_x = window_reverse(attn_windows, self.window_size, h, w) # b h' w' c
311
+
312
+ # reverse cyclic shift
313
+ if self.shift_size > 0:
314
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
315
+ else:
316
+ x = shifted_x
317
+ x = x.view(b, h * w, c)
318
+
319
+ # FFN
320
+ x = shortcut + self.drop_path(x)
321
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
322
+
323
+ return x
324
+
325
+ def extra_repr(self) -> str:
326
+ return (f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, '
327
+ f'window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}')
328
+
329
+ def flops(self):
330
+ flops = 0
331
+ h, w = self.input_resolution
332
+ # norm1
333
+ flops += self.dim * h * w
334
+ # W-MSA/SW-MSA
335
+ nw = h * w / self.window_size / self.window_size
336
+ flops += nw * self.attn.flops(self.window_size * self.window_size)
337
+ # mlp
338
+ flops += 2 * h * w * self.dim * self.dim * self.mlp_ratio
339
+ # norm2
340
+ flops += self.dim * h * w
341
+ return flops
342
+
343
+
344
+ class PatchMerging(nn.Module):
345
+ r""" Patch Merging Layer.
346
+
347
+ Args:
348
+ input_resolution (tuple[int]): Resolution of input feature.
349
+ dim (int): Number of input channels.
350
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
351
+ """
352
+
353
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
354
+ super().__init__()
355
+ self.input_resolution = input_resolution
356
+ self.dim = dim
357
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
358
+ self.norm = norm_layer(4 * dim)
359
+
360
+ def forward(self, x):
361
+ """
362
+ x: b, h*w, c
363
+ """
364
+ h, w = self.input_resolution
365
+ b, seq_len, c = x.shape
366
+ assert seq_len == h * w, 'input feature has wrong size'
367
+ assert h % 2 == 0 and w % 2 == 0, f'x size ({h}*{w}) are not even.'
368
+
369
+ x = x.view(b, h, w, c)
370
+
371
+ x0 = x[:, 0::2, 0::2, :] # b h/2 w/2 c
372
+ x1 = x[:, 1::2, 0::2, :] # b h/2 w/2 c
373
+ x2 = x[:, 0::2, 1::2, :] # b h/2 w/2 c
374
+ x3 = x[:, 1::2, 1::2, :] # b h/2 w/2 c
375
+ x = torch.cat([x0, x1, x2, x3], -1) # b h/2 w/2 4*c
376
+ x = x.view(b, -1, 4 * c) # b h/2*w/2 4*c
377
+
378
+ x = self.norm(x)
379
+ x = self.reduction(x)
380
+
381
+ return x
382
+
383
+ def extra_repr(self) -> str:
384
+ return f'input_resolution={self.input_resolution}, dim={self.dim}'
385
+
386
+ def flops(self):
387
+ h, w = self.input_resolution
388
+ flops = h * w * self.dim
389
+ flops += (h // 2) * (w // 2) * 4 * self.dim * 2 * self.dim
390
+ return flops
391
+
392
+
393
+ class BasicLayer(nn.Module):
394
+ """ A basic Swin Transformer layer for one stage.
395
+
396
+ Args:
397
+ dim (int): Number of input channels.
398
+ input_resolution (tuple[int]): Input resolution.
399
+ depth (int): Number of blocks.
400
+ num_heads (int): Number of attention heads.
401
+ window_size (int): Local window size.
402
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
403
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
404
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
405
+ drop (float, optional): Dropout rate. Default: 0.0
406
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
407
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
408
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
409
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
410
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
411
+ """
412
+
413
+ def __init__(self,
414
+ dim,
415
+ input_resolution,
416
+ depth,
417
+ num_heads,
418
+ window_size,
419
+ mlp_ratio=4.,
420
+ qkv_bias=True,
421
+ qk_scale=None,
422
+ drop=0.,
423
+ attn_drop=0.,
424
+ drop_path=0.,
425
+ norm_layer=nn.LayerNorm,
426
+ downsample=None,
427
+ use_checkpoint=False):
428
+
429
+ super().__init__()
430
+ self.dim = dim
431
+ self.input_resolution = input_resolution
432
+ self.depth = depth
433
+ self.use_checkpoint = use_checkpoint
434
+
435
+ # build blocks
436
+ self.blocks = nn.ModuleList([
437
+ SwinTransformerBlock(
438
+ dim=dim,
439
+ input_resolution=input_resolution,
440
+ num_heads=num_heads,
441
+ window_size=window_size,
442
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
443
+ mlp_ratio=mlp_ratio,
444
+ qkv_bias=qkv_bias,
445
+ qk_scale=qk_scale,
446
+ drop=drop,
447
+ attn_drop=attn_drop,
448
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
449
+ norm_layer=norm_layer) for i in range(depth)
450
+ ])
451
+
452
+ # patch merging layer
453
+ if downsample is not None:
454
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
455
+ else:
456
+ self.downsample = None
457
+
458
+ def forward(self, x, x_size):
459
+ for blk in self.blocks:
460
+ if self.use_checkpoint:
461
+ x = checkpoint.checkpoint(blk, x)
462
+ else:
463
+ x = blk(x, x_size)
464
+ if self.downsample is not None:
465
+ x = self.downsample(x)
466
+ return x
467
+
468
+ def extra_repr(self) -> str:
469
+ return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
470
+
471
+ def flops(self):
472
+ flops = 0
473
+ for blk in self.blocks:
474
+ flops += blk.flops()
475
+ if self.downsample is not None:
476
+ flops += self.downsample.flops()
477
+ return flops
478
+
479
+
480
+ class RSTB(nn.Module):
481
+ """Residual Swin Transformer Block (RSTB).
482
+
483
+ Args:
484
+ dim (int): Number of input channels.
485
+ input_resolution (tuple[int]): Input resolution.
486
+ depth (int): Number of blocks.
487
+ num_heads (int): Number of attention heads.
488
+ window_size (int): Local window size.
489
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
490
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
491
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
492
+ drop (float, optional): Dropout rate. Default: 0.0
493
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
494
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
495
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
496
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
497
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
498
+ img_size: Input image size.
499
+ patch_size: Patch size.
500
+ resi_connection: The convolutional block before residual connection.
501
+ """
502
+
503
+ def __init__(self,
504
+ dim,
505
+ input_resolution,
506
+ depth,
507
+ num_heads,
508
+ window_size,
509
+ mlp_ratio=4.,
510
+ qkv_bias=True,
511
+ qk_scale=None,
512
+ drop=0.,
513
+ attn_drop=0.,
514
+ drop_path=0.,
515
+ norm_layer=nn.LayerNorm,
516
+ downsample=None,
517
+ use_checkpoint=False,
518
+ img_size=224,
519
+ patch_size=4,
520
+ resi_connection='1conv'):
521
+ super(RSTB, self).__init__()
522
+
523
+ self.dim = dim
524
+ self.input_resolution = input_resolution
525
+
526
+ self.residual_group = BasicLayer(
527
+ dim=dim,
528
+ input_resolution=input_resolution,
529
+ depth=depth,
530
+ num_heads=num_heads,
531
+ window_size=window_size,
532
+ mlp_ratio=mlp_ratio,
533
+ qkv_bias=qkv_bias,
534
+ qk_scale=qk_scale,
535
+ drop=drop,
536
+ attn_drop=attn_drop,
537
+ drop_path=drop_path,
538
+ norm_layer=norm_layer,
539
+ downsample=downsample,
540
+ use_checkpoint=use_checkpoint)
541
+
542
+ if resi_connection == '1conv':
543
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
544
+ elif resi_connection == '3conv':
545
+ # to save parameters and memory
546
+ self.conv = nn.Sequential(
547
+ nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
548
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
549
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
550
+
551
+ self.patch_embed = PatchEmbed(
552
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
553
+
554
+ self.patch_unembed = PatchUnEmbed(
555
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
556
+
557
+ def forward(self, x, x_size):
558
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
559
+
560
+ def flops(self):
561
+ flops = 0
562
+ flops += self.residual_group.flops()
563
+ h, w = self.input_resolution
564
+ flops += h * w * self.dim * self.dim * 9
565
+ flops += self.patch_embed.flops()
566
+ flops += self.patch_unembed.flops()
567
+
568
+ return flops
569
+
570
+
571
+ class PatchEmbed(nn.Module):
572
+ r""" Image to Patch Embedding
573
+
574
+ Args:
575
+ img_size (int): Image size. Default: 224.
576
+ patch_size (int): Patch token size. Default: 4.
577
+ in_chans (int): Number of input image channels. Default: 3.
578
+ embed_dim (int): Number of linear projection output channels. Default: 96.
579
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
580
+ """
581
+
582
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
583
+ super().__init__()
584
+ img_size = to_2tuple(img_size)
585
+ patch_size = to_2tuple(patch_size)
586
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
587
+ self.img_size = img_size
588
+ self.patch_size = patch_size
589
+ self.patches_resolution = patches_resolution
590
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
591
+
592
+ self.in_chans = in_chans
593
+ self.embed_dim = embed_dim
594
+
595
+ if norm_layer is not None:
596
+ self.norm = norm_layer(embed_dim)
597
+ else:
598
+ self.norm = None
599
+
600
+ def forward(self, x):
601
+ x = x.flatten(2).transpose(1, 2) # b Ph*Pw c
602
+ if self.norm is not None:
603
+ x = self.norm(x)
604
+ return x
605
+
606
+ def flops(self):
607
+ flops = 0
608
+ h, w = self.img_size
609
+ if self.norm is not None:
610
+ flops += h * w * self.embed_dim
611
+ return flops
612
+
613
+
614
+ class PatchUnEmbed(nn.Module):
615
+ r""" Image to Patch Unembedding
616
+
617
+ Args:
618
+ img_size (int): Image size. Default: 224.
619
+ patch_size (int): Patch token size. Default: 4.
620
+ in_chans (int): Number of input image channels. Default: 3.
621
+ embed_dim (int): Number of linear projection output channels. Default: 96.
622
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
623
+ """
624
+
625
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
626
+ super().__init__()
627
+ img_size = to_2tuple(img_size)
628
+ patch_size = to_2tuple(patch_size)
629
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
630
+ self.img_size = img_size
631
+ self.patch_size = patch_size
632
+ self.patches_resolution = patches_resolution
633
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
634
+
635
+ self.in_chans = in_chans
636
+ self.embed_dim = embed_dim
637
+
638
+ def forward(self, x, x_size):
639
+ x = x.transpose(1, 2).view(x.shape[0], self.embed_dim, x_size[0], x_size[1]) # b c h w
640
+ return x
641
+
642
+ def flops(self):
643
+ flops = 0
644
+ return flops
645
+
646
+
647
+ class Upsample(nn.Sequential):
648
+ """Upsample module.
649
+
650
+ Args:
651
+ scale (int): Scale factor. Supported scales: 2^n and 3.
652
+ num_feat (int): Channel number of intermediate features.
653
+ """
654
+
655
+ def __init__(self, scale, num_feat):
656
+ m = []
657
+ if (scale & (scale - 1)) == 0: # scale = 2^n
658
+ for _ in range(int(math.log(scale, 2))):
659
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
660
+ m.append(nn.PixelShuffle(2))
661
+ elif scale == 3:
662
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
663
+ m.append(nn.PixelShuffle(3))
664
+ else:
665
+ raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
666
+ super(Upsample, self).__init__(*m)
667
+
668
+
669
+ class UpsampleOneStep(nn.Sequential):
670
+ """UpsampleOneStep module (unlike Upsample, it always uses a single conv followed by one pixelshuffle).
671
+ Used in lightweight SR to save parameters.
672
+
673
+ Args:
674
+ scale (int): Scale factor. Supported scales: 2^n and 3.
675
+ num_feat (int): Channel number of intermediate features.
676
+
677
+ """
678
+
679
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
680
+ self.num_feat = num_feat
681
+ self.input_resolution = input_resolution
682
+ m = []
683
+ m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
684
+ m.append(nn.PixelShuffle(scale))
685
+ super(UpsampleOneStep, self).__init__(*m)
686
+
687
+ def flops(self):
688
+ h, w = self.input_resolution
689
+ flops = h * w * self.num_feat * 3 * 9
690
+ return flops
691
+
692
+
693
+ @ARCH_REGISTRY.register()
694
+ class SwinIR(nn.Module):
695
+ r""" SwinIR
696
+ A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
697
+
698
+ Args:
699
+ img_size (int | tuple(int)): Input image size. Default 64
700
+ patch_size (int | tuple(int)): Patch size. Default: 1
701
+ in_chans (int): Number of input image channels. Default: 3
702
+ embed_dim (int): Patch embedding dimension. Default: 96
703
+ depths (tuple(int)): Depth of each Swin Transformer layer.
704
+ num_heads (tuple(int)): Number of attention heads in different layers.
705
+ window_size (int): Window size. Default: 7
706
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
707
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
708
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
709
+ drop_rate (float): Dropout rate. Default: 0
710
+ attn_drop_rate (float): Attention dropout rate. Default: 0
711
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
712
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
713
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
714
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
715
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
716
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
717
+ img_range: Image range. 1. or 255.
718
+ upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
719
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
720
+ """
721
+
722
+ def __init__(self,
723
+ img_size=64,
724
+ patch_size=1,
725
+ in_chans=3,
726
+ embed_dim=96,
727
+ depths=(6, 6, 6, 6),
728
+ num_heads=(6, 6, 6, 6),
729
+ window_size=7,
730
+ mlp_ratio=4.,
731
+ qkv_bias=True,
732
+ qk_scale=None,
733
+ drop_rate=0.,
734
+ attn_drop_rate=0.,
735
+ drop_path_rate=0.1,
736
+ norm_layer=nn.LayerNorm,
737
+ ape=False,
738
+ patch_norm=True,
739
+ use_checkpoint=False,
740
+ upscale=2,
741
+ img_range=1.,
742
+ upsampler='',
743
+ resi_connection='1conv',
744
+ **kwargs):
745
+ super(SwinIR, self).__init__()
746
+ num_in_ch = in_chans
747
+ num_out_ch = in_chans
748
+ num_feat = 64
749
+ self.img_range = img_range
750
+ if in_chans == 3:
751
+ rgb_mean = (0.4488, 0.4371, 0.4040)
752
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
753
+ else:
754
+ self.mean = torch.zeros(1, 1, 1, 1)
755
+ self.upscale = upscale
756
+ self.upsampler = upsampler
757
+
758
+ # ------------------------- 1, shallow feature extraction ------------------------- #
759
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
760
+
761
+ # ------------------------- 2, deep feature extraction ------------------------- #
762
+ self.num_layers = len(depths)
763
+ self.embed_dim = embed_dim
764
+ self.ape = ape
765
+ self.patch_norm = patch_norm
766
+ self.num_features = embed_dim
767
+ self.mlp_ratio = mlp_ratio
768
+
769
+ # split image into non-overlapping patches
770
+ self.patch_embed = PatchEmbed(
771
+ img_size=img_size,
772
+ patch_size=patch_size,
773
+ in_chans=embed_dim,
774
+ embed_dim=embed_dim,
775
+ norm_layer=norm_layer if self.patch_norm else None)
776
+ num_patches = self.patch_embed.num_patches
777
+ patches_resolution = self.patch_embed.patches_resolution
778
+ self.patches_resolution = patches_resolution
779
+
780
+ # merge non-overlapping patches into image
781
+ self.patch_unembed = PatchUnEmbed(
782
+ img_size=img_size,
783
+ patch_size=patch_size,
784
+ in_chans=embed_dim,
785
+ embed_dim=embed_dim,
786
+ norm_layer=norm_layer if self.patch_norm else None)
787
+
788
+ # absolute position embedding
789
+ if self.ape:
790
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
791
+ trunc_normal_(self.absolute_pos_embed, std=.02)
792
+
793
+ self.pos_drop = nn.Dropout(p=drop_rate)
794
+
795
+ # stochastic depth
796
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
797
+
798
+ # build Residual Swin Transformer blocks (RSTB)
799
+ self.layers = nn.ModuleList()
800
+ for i_layer in range(self.num_layers):
801
+ layer = RSTB(
802
+ dim=embed_dim,
803
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
804
+ depth=depths[i_layer],
805
+ num_heads=num_heads[i_layer],
806
+ window_size=window_size,
807
+ mlp_ratio=self.mlp_ratio,
808
+ qkv_bias=qkv_bias,
809
+ qk_scale=qk_scale,
810
+ drop=drop_rate,
811
+ attn_drop=attn_drop_rate,
812
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
813
+ norm_layer=norm_layer,
814
+ downsample=None,
815
+ use_checkpoint=use_checkpoint,
816
+ img_size=img_size,
817
+ patch_size=patch_size,
818
+ resi_connection=resi_connection)
819
+ self.layers.append(layer)
820
+ self.norm = norm_layer(self.num_features)
821
+
822
+ # build the last conv layer in deep feature extraction
823
+ if resi_connection == '1conv':
824
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
825
+ elif resi_connection == '3conv':
826
+ # to save parameters and memory
827
+ self.conv_after_body = nn.Sequential(
828
+ nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
829
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
830
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
831
+
832
+ # ------------------------- 3, high quality image reconstruction ------------------------- #
833
+ if self.upsampler == 'pixelshuffle':
834
+ # for classical SR
835
+ self.conv_before_upsample = nn.Sequential(
836
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
837
+ self.upsample = Upsample(upscale, num_feat)
838
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
839
+ elif self.upsampler == 'pixelshuffledirect':
840
+ # for lightweight SR (to save parameters)
841
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
842
+ (patches_resolution[0], patches_resolution[1]))
843
+ elif self.upsampler == 'nearest+conv':
844
+ # for real-world SR (less artifacts)
845
+ assert self.upscale == 4, 'only support x4 now.'
846
+ self.conv_before_upsample = nn.Sequential(
847
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
848
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
849
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
850
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
851
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
852
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
853
+ else:
854
+ # for image denoising and JPEG compression artifact reduction
855
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
856
+
857
+ self.apply(self._init_weights)
858
+
859
+ def _init_weights(self, m):
860
+ if isinstance(m, nn.Linear):
861
+ trunc_normal_(m.weight, std=.02)
862
+ if isinstance(m, nn.Linear) and m.bias is not None:
863
+ nn.init.constant_(m.bias, 0)
864
+ elif isinstance(m, nn.LayerNorm):
865
+ nn.init.constant_(m.bias, 0)
866
+ nn.init.constant_(m.weight, 1.0)
867
+
868
+ @torch.jit.ignore
869
+ def no_weight_decay(self):
870
+ return {'absolute_pos_embed'}
871
+
872
+ @torch.jit.ignore
873
+ def no_weight_decay_keywords(self):
874
+ return {'relative_position_bias_table'}
875
+
876
+ def forward_features(self, x):
877
+ x_size = (x.shape[2], x.shape[3])
878
+ x = self.patch_embed(x)
879
+ if self.ape:
880
+ x = x + self.absolute_pos_embed
881
+ x = self.pos_drop(x)
882
+
883
+ for layer in self.layers:
884
+ x = layer(x, x_size)
885
+
886
+ x = self.norm(x) # b seq_len c
887
+ x = self.patch_unembed(x, x_size)
888
+
889
+ return x
890
+
891
+ def forward(self, x):
892
+ self.mean = self.mean.type_as(x)
893
+ x = (x - self.mean) * self.img_range
894
+
895
+ if self.upsampler == 'pixelshuffle':
896
+ # for classical SR
897
+ x = self.conv_first(x)
898
+ x = self.conv_after_body(self.forward_features(x)) + x
899
+ x = self.conv_before_upsample(x)
900
+ x = self.conv_last(self.upsample(x))
901
+ elif self.upsampler == 'pixelshuffledirect':
902
+ # for lightweight SR
903
+ x = self.conv_first(x)
904
+ x = self.conv_after_body(self.forward_features(x)) + x
905
+ x = self.upsample(x)
906
+ elif self.upsampler == 'nearest+conv':
907
+ # for real-world SR
908
+ x = self.conv_first(x)
909
+ x = self.conv_after_body(self.forward_features(x)) + x
910
+ x = self.conv_before_upsample(x)
911
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
912
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
913
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
914
+ else:
915
+ # for image denoising and JPEG compression artifact reduction
916
+ x_first = self.conv_first(x)
917
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
918
+ x = x + self.conv_last(res)
919
+
920
+ x = x / self.img_range + self.mean
921
+
922
+ return x
923
+
924
+ def flops(self):
925
+ flops = 0
926
+ h, w = self.patches_resolution
927
+ flops += h * w * 3 * self.embed_dim * 9
928
+ flops += self.patch_embed.flops()
929
+ for layer in self.layers:
930
+ flops += layer.flops()
931
+ flops += h * w * 3 * self.embed_dim * self.embed_dim
932
+ flops += self.upsample.flops()
933
+ return flops
934
+
935
+
936
+ if __name__ == '__main__':
937
+ upscale = 4
938
+ window_size = 8
939
+ height = (1024 // upscale // window_size + 1) * window_size
940
+ width = (720 // upscale // window_size + 1) * window_size
941
+ model = SwinIR(
942
+ upscale=2,
943
+ img_size=(height, width),
944
+ window_size=window_size,
945
+ img_range=1.,
946
+ depths=[6, 6, 6, 6],
947
+ embed_dim=60,
948
+ num_heads=[6, 6, 6, 6],
949
+ mlp_ratio=2,
950
+ upsampler='pixelshuffledirect')
951
+ print(model)
952
+ print(height, width, model.flops() / 1e9)
953
+
954
+ x = torch.randn((1, 3, height, width))
955
+ x = model(x)
956
+ print(x.shape)
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/tof_arch.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ from r_basicsr.utils.registry import ARCH_REGISTRY
6
+ from .arch_util import flow_warp
7
+
8
+
9
+ class BasicModule(nn.Module):
10
+ """Basic module of SPyNet.
11
+
12
+ Note that unlike the architecture in spynet_arch.py, the basic module
13
+ here contains batch normalization.
14
+ """
15
+
16
+ def __init__(self):
17
+ super(BasicModule, self).__init__()
18
+ self.basic_module = nn.Sequential(
19
+ nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False),
20
+ nn.BatchNorm2d(32), nn.ReLU(inplace=True),
21
+ nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3, bias=False),
22
+ nn.BatchNorm2d(64), nn.ReLU(inplace=True),
23
+ nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False),
24
+ nn.BatchNorm2d(32), nn.ReLU(inplace=True),
25
+ nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3, bias=False),
26
+ nn.BatchNorm2d(16), nn.ReLU(inplace=True),
27
+ nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))
28
+
29
+ def forward(self, tensor_input):
30
+ """
31
+ Args:
32
+ tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
33
+ 8 channels contain:
34
+ [reference image (3), neighbor image (3), initial flow (2)].
35
+
36
+ Returns:
37
+ Tensor: Estimated flow with shape (b, 2, h, w)
38
+ """
39
+ return self.basic_module(tensor_input)
40
+
41
+
42
+ class SPyNetTOF(nn.Module):
43
+ """SPyNet architecture for TOF.
44
+
45
+ Note that this implementation is specifically for TOFlow. Please use
46
+ spynet_arch.py for general use. They differ in the following aspects:
47
+ 1. The basic modules here contain BatchNorm.
48
+ 2. Normalization and denormalization are not done here, as
49
+ they are done in TOFlow.
50
+ Paper:
51
+ Optical Flow Estimation using a Spatial Pyramid Network
52
+ Code reference:
53
+ https://github.com/Coldog2333/pytoflow
54
+
55
+ Args:
56
+ load_path (str): Path for pretrained SPyNet. Default: None.
57
+ """
58
+
59
+ def __init__(self, load_path=None):
60
+ super(SPyNetTOF, self).__init__()
61
+
62
+ self.basic_module = nn.ModuleList([BasicModule() for _ in range(4)])
63
+ if load_path:
64
+ self.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage)['params'])
65
+
66
+ def forward(self, ref, supp):
67
+ """
68
+ Args:
69
+ ref (Tensor): Reference image with shape of (b, 3, h, w).
70
+ supp: The supporting image to be warped: (b, 3, h, w).
71
+
72
+ Returns:
73
+ Tensor: Estimated optical flow: (b, 2, h, w).
74
+ """
75
+ num_batches, _, h, w = ref.size()
76
+ ref = [ref]
77
+ supp = [supp]
78
+
79
+ # generate downsampled frames
80
+ for _ in range(3):
81
+ ref.insert(0, F.avg_pool2d(input=ref[0], kernel_size=2, stride=2, count_include_pad=False))
82
+ supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False))
83
+
84
+ # flow computation
85
+ flow = ref[0].new_zeros(num_batches, 2, h // 16, w // 16)
86
+ for i in range(4):
87
+ flow_up = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
88
+ flow = flow_up + self.basic_module[i](
89
+ torch.cat([ref[i], flow_warp(supp[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1))
90
+ return flow
91
+
92
+
93
+ @ARCH_REGISTRY.register()
94
+ class TOFlow(nn.Module):
95
+ """PyTorch implementation of TOFlow.
96
+
97
+ In TOFlow, the LR frames are pre-upsampled and have the same size with
98
+ the GT frames.
99
+ Paper:
100
+ Xue et al., Video Enhancement with Task-Oriented Flow, IJCV 2018
101
+ Code reference:
102
+ 1. https://github.com/anchen1011/toflow
103
+ 2. https://github.com/Coldog2333/pytoflow
104
+
105
+ Args:
106
+ adapt_official_weights (bool): Whether to adapt the weights translated
107
+ from the official implementation. Set to false if you want to
108
+ train from scratch. Default: False
109
+ """
110
+
111
+ def __init__(self, adapt_official_weights=False):
112
+ super(TOFlow, self).__init__()
113
+ self.adapt_official_weights = adapt_official_weights
114
+ self.ref_idx = 0 if adapt_official_weights else 3
115
+
116
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
117
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
118
+
119
+ # flow estimation module
120
+ self.spynet = SPyNetTOF()
121
+
122
+ # reconstruction module
123
+ self.conv_1 = nn.Conv2d(3 * 7, 64, 9, 1, 4)
124
+ self.conv_2 = nn.Conv2d(64, 64, 9, 1, 4)
125
+ self.conv_3 = nn.Conv2d(64, 64, 1)
126
+ self.conv_4 = nn.Conv2d(64, 3, 1)
127
+
128
+ # activation function
129
+ self.relu = nn.ReLU(inplace=True)
130
+
131
+ def normalize(self, img):
132
+ return (img - self.mean) / self.std
133
+
134
+ def denormalize(self, img):
135
+ return img * self.std + self.mean
136
+
137
+ def forward(self, lrs):
138
+ """
139
+ Args:
140
+ lrs: Input lr frames: (b, 7, 3, h, w).
141
+
142
+ Returns:
143
+ Tensor: SR frame: (b, 3, h, w).
144
+ """
145
+ # In the official implementation, the 0-th frame is the reference frame
146
+ if self.adapt_official_weights:
147
+ lrs = lrs[:, [3, 0, 1, 2, 4, 5, 6], :, :, :]
148
+
149
+ num_batches, num_lrs, _, h, w = lrs.size()
150
+
151
+ lrs = self.normalize(lrs.view(-1, 3, h, w))
152
+ lrs = lrs.view(num_batches, num_lrs, 3, h, w)
153
+
154
+ lr_ref = lrs[:, self.ref_idx, :, :, :]
155
+ lr_aligned = []
156
+ for i in range(7): # 7 frames
157
+ if i == self.ref_idx:
158
+ lr_aligned.append(lr_ref)
159
+ else:
160
+ lr_supp = lrs[:, i, :, :, :]
161
+ flow = self.spynet(lr_ref, lr_supp)
162
+ lr_aligned.append(flow_warp(lr_supp, flow.permute(0, 2, 3, 1)))
163
+
164
+ # reconstruction
165
+ hr = torch.stack(lr_aligned, dim=1)
166
+ hr = hr.view(num_batches, -1, h, w)
167
+ hr = self.relu(self.conv_1(hr))
168
+ hr = self.relu(self.conv_2(hr))
169
+ hr = self.relu(self.conv_3(hr))
170
+ hr = self.conv_4(hr) + lr_ref
171
+
172
+ return self.denormalize(hr)
custom_nodes/ComfyUI-ReActor/r_basicsr/archs/vgg_arch.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from collections import OrderedDict
4
+ from torch import nn as nn
5
+ from torchvision.models import vgg as vgg
6
+
7
+ from r_basicsr.utils.registry import ARCH_REGISTRY
8
+
9
+ VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
10
+ NAMES = {
11
+ 'vgg11': [
12
+ 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
13
+ 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
14
+ 'pool5'
15
+ ],
16
+ 'vgg13': [
17
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
18
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
19
+ 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
20
+ ],
21
+ 'vgg16': [
22
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
23
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
24
+ 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
25
+ 'pool5'
26
+ ],
27
+ 'vgg19': [
28
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
29
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
30
+ 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
31
+ 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
32
+ ]
33
+ }
34
+
35
+
36
+ def insert_bn(names):
37
+ """Insert bn layer after each conv.
38
+
39
+ Args:
40
+ names (list): The list of layer names.
41
+
42
+ Returns:
43
+ list: The list of layer names with bn layers.
44
+ """
45
+ names_bn = []
46
+ for name in names:
47
+ names_bn.append(name)
48
+ if 'conv' in name:
49
+ position = name.replace('conv', '')
50
+ names_bn.append('bn' + position)
51
+ return names_bn
52
+
53
+
54
+ @ARCH_REGISTRY.register()
55
+ class VGGFeatureExtractor(nn.Module):
56
+ """VGG network for feature extraction.
57
+
58
+ In this implementation, we allow users to choose whether use normalization
59
+ in the input feature and the type of vgg network. Note that the pretrained
60
+ path must fit the vgg type.
61
+
62
+ Args:
63
+ layer_name_list (list[str]): Forward function returns the corresponding
64
+ features according to the layer_name_list.
65
+ Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
66
+ vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
67
+ use_input_norm (bool): If True, normalize the input image. Importantly,
68
+ the input feature must in the range [0, 1]. Default: True.
69
+ range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
70
+ Default: False.
71
+ requires_grad (bool): If true, the parameters of VGG network will be
72
+ optimized. Default: False.
73
+ remove_pooling (bool): If true, the max pooling operations in VGG net
74
+ will be removed. Default: False.
75
+ pooling_stride (int): The stride of max pooling operation. Default: 2.
76
+ """
77
+
78
+ def __init__(self,
79
+ layer_name_list,
80
+ vgg_type='vgg19',
81
+ use_input_norm=True,
82
+ range_norm=False,
83
+ requires_grad=False,
84
+ remove_pooling=False,
85
+ pooling_stride=2):
86
+ super(VGGFeatureExtractor, self).__init__()
87
+
88
+ self.layer_name_list = layer_name_list
89
+ self.use_input_norm = use_input_norm
90
+ self.range_norm = range_norm
91
+
92
+ self.names = NAMES[vgg_type.replace('_bn', '')]
93
+ if 'bn' in vgg_type:
94
+ self.names = insert_bn(self.names)
95
+
96
+ # only borrow layers that will be used to avoid unused params
97
+ max_idx = 0
98
+ for v in layer_name_list:
99
+ idx = self.names.index(v)
100
+ if idx > max_idx:
101
+ max_idx = idx
102
+
103
+ if os.path.exists(VGG_PRETRAIN_PATH):
104
+ vgg_net = getattr(vgg, vgg_type)(pretrained=False)
105
+ state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
106
+ vgg_net.load_state_dict(state_dict)
107
+ else:
108
+ vgg_net = getattr(vgg, vgg_type)(pretrained=True)
109
+
110
+ features = vgg_net.features[:max_idx + 1]
111
+
112
+ modified_net = OrderedDict()
113
+ for k, v in zip(self.names, features):
114
+ if 'pool' in k:
115
+ # if remove_pooling is true, pooling operation will be removed
116
+ if remove_pooling:
117
+ continue
118
+ else:
119
+ # in some cases, we may want to change the default stride
120
+ modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
121
+ else:
122
+ modified_net[k] = v
123
+
124
+ self.vgg_net = nn.Sequential(modified_net)
125
+
126
+ if not requires_grad:
127
+ self.vgg_net.eval()
128
+ for param in self.parameters():
129
+ param.requires_grad = False
130
+ else:
131
+ self.vgg_net.train()
132
+ for param in self.parameters():
133
+ param.requires_grad = True
134
+
135
+ if self.use_input_norm:
136
+ # the mean is for image with range [0, 1]
137
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
138
+ # the std is for image with range [0, 1]
139
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
140
+
141
+ def forward(self, x):
142
+ """Forward function.
143
+
144
+ Args:
145
+ x (Tensor): Input tensor with shape (n, c, h, w).
146
+
147
+ Returns:
148
+ Tensor: Forward results.
149
+ """
150
+ if self.range_norm:
151
+ x = (x + 1) / 2
152
+ if self.use_input_norm:
153
+ x = (x - self.mean) / self.std
154
+
155
+ output = {}
156
+ for key, layer in self.vgg_net._modules.items():
157
+ x = layer(x)
158
+ if key in self.layer_name_list:
159
+ output[key] = x.clone()
160
+
161
+ return output
custom_nodes/ComfyUI-ReActor/r_basicsr/data/__init__.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+ import numpy as np
3
+ import random
4
+ import torch
5
+ import torch.utils.data
6
+ from copy import deepcopy
7
+ from functools import partial
8
+ from os import path as osp
9
+
10
+ from r_basicsr.data.prefetch_dataloader import PrefetchDataLoader
11
+ from r_basicsr.utils import get_root_logger, scandir
12
+ from r_basicsr.utils.dist_util import get_dist_info
13
+ from r_basicsr.utils.registry import DATASET_REGISTRY
14
+
15
+ __all__ = ['build_dataset', 'build_dataloader']
16
+
17
+ # automatically scan and import dataset modules for registry
18
+ # scan all the files under the data folder with '_dataset' in file names
19
+ data_folder = osp.dirname(osp.abspath(__file__))
20
+ dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
21
+ # import all the dataset modules
22
+ _dataset_modules = [importlib.import_module(f'r_basicsr.data.{file_name}') for file_name in dataset_filenames]
23
+
24
+
25
+ def build_dataset(dataset_opt):
26
+ """Build dataset from options.
27
+
28
+ Args:
29
+ dataset_opt (dict): Configuration for dataset. It must contain:
30
+ name (str): Dataset name.
31
+ type (str): Dataset type.
32
+ """
33
+ dataset_opt = deepcopy(dataset_opt)
34
+ dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
35
+ logger = get_root_logger()
36
+ logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.')
37
+ return dataset
38
+
39
+
40
+ def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
41
+ """Build dataloader.
42
+
43
+ Args:
44
+ dataset (torch.utils.data.Dataset): Dataset.
45
+ dataset_opt (dict): Dataset options. It contains the following keys:
46
+ phase (str): 'train' or 'val'.
47
+ num_worker_per_gpu (int): Number of workers for each GPU.
48
+ batch_size_per_gpu (int): Training batch size for each GPU.
49
+ num_gpu (int): Number of GPUs. Used only in the train phase.
50
+ Default: 1.
51
+ dist (bool): Whether in distributed training. Used only in the train
52
+ phase. Default: False.
53
+ sampler (torch.utils.data.sampler): Data sampler. Default: None.
54
+ seed (int | None): Seed. Default: None
55
+ """
56
+ phase = dataset_opt['phase']
57
+ rank, _ = get_dist_info()
58
+ if phase == 'train':
59
+ if dist: # distributed training
60
+ batch_size = dataset_opt['batch_size_per_gpu']
61
+ num_workers = dataset_opt['num_worker_per_gpu']
62
+ else: # non-distributed training
63
+ multiplier = 1 if num_gpu == 0 else num_gpu
64
+ batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
65
+ num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
66
+ dataloader_args = dict(
67
+ dataset=dataset,
68
+ batch_size=batch_size,
69
+ shuffle=False,
70
+ num_workers=num_workers,
71
+ sampler=sampler,
72
+ drop_last=True)
73
+ if sampler is None:
74
+ dataloader_args['shuffle'] = True
75
+ dataloader_args['worker_init_fn'] = partial(
76
+ worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
77
+ elif phase in ['val', 'test']: # validation
78
+ dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
79
+ else:
80
+ raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")
81
+
82
+ dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
83
+ dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)
84
+
85
+ prefetch_mode = dataset_opt.get('prefetch_mode')
86
+ if prefetch_mode == 'cpu': # CPUPrefetcher
87
+ num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
88
+ logger = get_root_logger()
89
+ logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
90
+ return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
91
+ else:
92
+ # prefetch_mode=None: Normal dataloader
93
+ # prefetch_mode='cuda': dataloader for CUDAPrefetcher
94
+ return torch.utils.data.DataLoader(**dataloader_args)
95
+
96
+
97
+ def worker_init_fn(worker_id, num_workers, rank, seed):
98
+ # Set the worker seed to num_workers * rank + worker_id + seed
99
+ worker_seed = num_workers * rank + worker_id + seed
100
+ np.random.seed(worker_seed)
101
+ random.seed(worker_seed)
custom_nodes/ComfyUI-ReActor/r_basicsr/data/data_sampler.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch.utils.data.sampler import Sampler
4
+
5
+
6
+ class EnlargedSampler(Sampler):
7
+ """Sampler that restricts data loading to a subset of the dataset.
8
+
9
+ Modified from torch.utils.data.distributed.DistributedSampler
10
+ Support enlarging the dataset for iteration-based training, for saving
11
+ time when restart the dataloader after each epoch
12
+
13
+ Args:
14
+ dataset (torch.utils.data.Dataset): Dataset used for sampling.
15
+ num_replicas (int | None): Number of processes participating in
16
+ the training. It is usually the world_size.
17
+ rank (int | None): Rank of the current process within num_replicas.
18
+ ratio (int): Enlarging ratio. Default: 1.
19
+ """
20
+
21
+ def __init__(self, dataset, num_replicas, rank, ratio=1):
22
+ self.dataset = dataset
23
+ self.num_replicas = num_replicas
24
+ self.rank = rank
25
+ self.epoch = 0
26
+ self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
27
+ self.total_size = self.num_samples * self.num_replicas
28
+
29
+ def __iter__(self):
30
+ # deterministically shuffle based on epoch
31
+ g = torch.Generator()
32
+ g.manual_seed(self.epoch)
33
+ indices = torch.randperm(self.total_size, generator=g).tolist()
34
+
35
+ dataset_size = len(self.dataset)
36
+ indices = [v % dataset_size for v in indices]
37
+
38
+ # subsample
39
+ indices = indices[self.rank:self.total_size:self.num_replicas]
40
+ assert len(indices) == self.num_samples
41
+
42
+ return iter(indices)
43
+
44
+ def __len__(self):
45
+ return self.num_samples
46
+
47
+ def set_epoch(self, epoch):
48
+ self.epoch = epoch
custom_nodes/ComfyUI-ReActor/r_basicsr/data/data_util.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from os import path as osp
5
+ from torch.nn import functional as F
6
+
7
+ from r_basicsr.data.transforms import mod_crop
8
+ from r_basicsr.utils import img2tensor, scandir
9
+
10
+
11
+ def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
12
+ """Read a sequence of images from a given folder path.
13
+
14
+ Args:
15
+ path (list[str] | str): List of image paths or image folder path.
16
+ require_mod_crop (bool): Require mod crop for each image.
17
+ Default: False.
18
+ scale (int): Scale factor for mod_crop. Default: 1.
19
+ return_imgname(bool): Whether return image names. Default False.
20
+
21
+ Returns:
22
+ Tensor: size (t, c, h, w), RGB, [0, 1].
23
+ list[str]: Returned image name list.
24
+ """
25
+ if isinstance(path, list):
26
+ img_paths = path
27
+ else:
28
+ img_paths = sorted(list(scandir(path, full_path=True)))
29
+ imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
30
+
31
+ if require_mod_crop:
32
+ imgs = [mod_crop(img, scale) for img in imgs]
33
+ imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
34
+ imgs = torch.stack(imgs, dim=0)
35
+
36
+ if return_imgname:
37
+ imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
38
+ return imgs, imgnames
39
+ else:
40
+ return imgs
41
+
42
+
43
+ def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
44
+ """Generate an index list for reading `num_frames` frames from a sequence
45
+ of images.
46
+
47
+ Args:
48
+ crt_idx (int): Current center index.
49
+ max_frame_num (int): Max number of the sequence of images (from 1).
50
+ num_frames (int): Reading num_frames frames.
51
+ padding (str): Padding mode, one of
52
+ 'replicate' | 'reflection' | 'reflection_circle' | 'circle'
53
+ Examples: current_idx = 0, num_frames = 5
54
+ The generated frame indices under different padding mode:
55
+ replicate: [0, 0, 0, 1, 2]
56
+ reflection: [2, 1, 0, 1, 2]
57
+ reflection_circle: [4, 3, 0, 1, 2]
58
+ circle: [3, 4, 0, 1, 2]
59
+
60
+ Returns:
61
+ list[int]: A list of indices.
62
+ """
63
+ assert num_frames % 2 == 1, 'num_frames should be an odd number.'
64
+ assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
65
+
66
+ max_frame_num = max_frame_num - 1 # start from 0
67
+ num_pad = num_frames // 2
68
+
69
+ indices = []
70
+ for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
71
+ if i < 0:
72
+ if padding == 'replicate':
73
+ pad_idx = 0
74
+ elif padding == 'reflection':
75
+ pad_idx = -i
76
+ elif padding == 'reflection_circle':
77
+ pad_idx = crt_idx + num_pad - i
78
+ else:
79
+ pad_idx = num_frames + i
80
+ elif i > max_frame_num:
81
+ if padding == 'replicate':
82
+ pad_idx = max_frame_num
83
+ elif padding == 'reflection':
84
+ pad_idx = max_frame_num * 2 - i
85
+ elif padding == 'reflection_circle':
86
+ pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
87
+ else:
88
+ pad_idx = i - num_frames
89
+ else:
90
+ pad_idx = i
91
+ indices.append(pad_idx)
92
+ return indices
93
+
94
+
95
+ def paired_paths_from_lmdb(folders, keys):
96
+ """Generate paired paths from lmdb files.
97
+
98
+ Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:
99
+
100
+ lq.lmdb
101
+ ├── data.mdb
102
+ ├── lock.mdb
103
+ ├── meta_info.txt
104
+
105
+ The data.mdb and lock.mdb are standard lmdb files and you can refer to
106
+ https://lmdb.readthedocs.io/en/release/ for more details.
107
+
108
+ The meta_info.txt is a specified txt file to record the meta information
109
+ of our datasets. It will be automatically created when preparing
110
+ datasets by our provided dataset tools.
111
+ Each line in the txt file records
112
+ 1)image name (with extension),
113
+ 2)image shape,
114
+ 3)compression level, separated by a white space.
115
+ Example: `baboon.png (120,125,3) 1`
116
+
117
+ We use the image name without extension as the lmdb key.
118
+ Note that we use the same key for the corresponding lq and gt images.
119
+
120
+ Args:
121
+ folders (list[str]): A list of folder path. The order of list should
122
+ be [input_folder, gt_folder].
123
+ keys (list[str]): A list of keys identifying folders. The order should
124
+ be in consistent with folders, e.g., ['lq', 'gt'].
125
+ Note that this key is different from lmdb keys.
126
+
127
+ Returns:
128
+ list[str]: Returned path list.
129
+ """
130
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
131
+ f'But got {len(folders)}')
132
+ assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
133
+ input_folder, gt_folder = folders
134
+ input_key, gt_key = keys
135
+
136
+ if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
137
+ raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
138
+ f'formats. But received {input_key}: {input_folder}; '
139
+ f'{gt_key}: {gt_folder}')
140
+ # ensure that the two meta_info files are the same
141
+ with open(osp.join(input_folder, 'meta_info.txt')) as fin:
142
+ input_lmdb_keys = [line.split('.')[0] for line in fin]
143
+ with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
144
+ gt_lmdb_keys = [line.split('.')[0] for line in fin]
145
+ if set(input_lmdb_keys) != set(gt_lmdb_keys):
146
+ raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
147
+ else:
148
+ paths = []
149
+ for lmdb_key in sorted(input_lmdb_keys):
150
+ paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
151
+ return paths
152
+
153
+
154
+ def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
155
+ """Generate paired paths from an meta information file.
156
+
157
+ Each line in the meta information file contains the image names and
158
+ image shape (usually for gt), separated by a white space.
159
+
160
+ Example of an meta information file:
161
+ ```
162
+ 0001_s001.png (480,480,3)
163
+ 0001_s002.png (480,480,3)
164
+ ```
165
+
166
+ Args:
167
+ folders (list[str]): A list of folder path. The order of list should
168
+ be [input_folder, gt_folder].
169
+ keys (list[str]): A list of keys identifying folders. The order should
170
+ be in consistent with folders, e.g., ['lq', 'gt'].
171
+ meta_info_file (str): Path to the meta information file.
172
+ filename_tmpl (str): Template for each filename. Note that the
173
+ template excludes the file extension. Usually the filename_tmpl is
174
+ for files in the input folder.
175
+
176
+ Returns:
177
+ list[str]: Returned path list.
178
+ """
179
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
180
+ f'But got {len(folders)}')
181
+ assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
182
+ input_folder, gt_folder = folders
183
+ input_key, gt_key = keys
184
+
185
+ with open(meta_info_file, 'r') as fin:
186
+ gt_names = [line.strip().split(' ')[0] for line in fin]
187
+
188
+ paths = []
189
+ for gt_name in gt_names:
190
+ basename, ext = osp.splitext(osp.basename(gt_name))
191
+ input_name = f'{filename_tmpl.format(basename)}{ext}'
192
+ input_path = osp.join(input_folder, input_name)
193
+ gt_path = osp.join(gt_folder, gt_name)
194
+ paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
195
+ return paths
196
+
197
+
198
+ def paired_paths_from_folder(folders, keys, filename_tmpl):
199
+ """Generate paired paths from folders.
200
+
201
+ Args:
202
+ folders (list[str]): A list of folder path. The order of list should
203
+ be [input_folder, gt_folder].
204
+ keys (list[str]): A list of keys identifying folders. The order should
205
+ be in consistent with folders, e.g., ['lq', 'gt'].
206
+ filename_tmpl (str): Template for each filename. Note that the
207
+ template excludes the file extension. Usually the filename_tmpl is
208
+ for files in the input folder.
209
+
210
+ Returns:
211
+ list[str]: Returned path list.
212
+ """
213
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
214
+ f'But got {len(folders)}')
215
+ assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
216
+ input_folder, gt_folder = folders
217
+ input_key, gt_key = keys
218
+
219
+ input_paths = list(scandir(input_folder))
220
+ gt_paths = list(scandir(gt_folder))
221
+ assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
222
+ f'{len(input_paths)}, {len(gt_paths)}.')
223
+ paths = []
224
+ for gt_path in gt_paths:
225
+ basename, ext = osp.splitext(osp.basename(gt_path))
226
+ input_name = f'{filename_tmpl.format(basename)}{ext}'
227
+ input_path = osp.join(input_folder, input_name)
228
+ assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.'
229
+ gt_path = osp.join(gt_folder, gt_path)
230
+ paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
231
+ return paths
232
+
233
+
234
+ def paths_from_folder(folder):
235
+ """Generate paths from folder.
236
+
237
+ Args:
238
+ folder (str): Folder path.
239
+
240
+ Returns:
241
+ list[str]: Returned path list.
242
+ """
243
+
244
+ paths = list(scandir(folder))
245
+ paths = [osp.join(folder, path) for path in paths]
246
+ return paths
247
+
248
+
249
+ def paths_from_lmdb(folder):
250
+ """Generate paths from lmdb.
251
+
252
+ Args:
253
+ folder (str): Folder path.
254
+
255
+ Returns:
256
+ list[str]: Returned path list.
257
+ """
258
+ if not folder.endswith('.lmdb'):
259
+ raise ValueError(f'Folder {folder}folder should in lmdb format.')
260
+ with open(osp.join(folder, 'meta_info.txt')) as fin:
261
+ paths = [line.split('.')[0] for line in fin]
262
+ return paths
263
+
264
+
265
+ def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
266
+ """Generate Gaussian kernel used in `duf_downsample`.
267
+
268
+ Args:
269
+ kernel_size (int): Kernel size. Default: 13.
270
+ sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
271
+
272
+ Returns:
273
+ np.array: The Gaussian kernel.
274
+ """
275
+ from scipy.ndimage import filters as filters
276
+ kernel = np.zeros((kernel_size, kernel_size))
277
+ # set element at the middle to one, a dirac delta
278
+ kernel[kernel_size // 2, kernel_size // 2] = 1
279
+ # gaussian-smooth the dirac, resulting in a gaussian filter
280
+ return filters.gaussian_filter(kernel, sigma)
281
+
282
+
283
+ def duf_downsample(x, kernel_size=13, scale=4):
284
+ """Downsamping with Gaussian kernel used in the DUF official code.
285
+
286
+ Args:
287
+ x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
288
+ kernel_size (int): Kernel size. Default: 13.
289
+ scale (int): Downsampling factor. Supported scale: (2, 3, 4).
290
+ Default: 4.
291
+
292
+ Returns:
293
+ Tensor: DUF downsampled frames.
294
+ """
295
+ assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
296
+
297
+ squeeze_flag = False
298
+ if x.ndim == 4:
299
+ squeeze_flag = True
300
+ x = x.unsqueeze(0)
301
+ b, t, c, h, w = x.size()
302
+ x = x.view(-1, 1, h, w)
303
+ pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
304
+ x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
305
+
306
+ gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
307
+ gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
308
+ x = F.conv2d(x, gaussian_filter, stride=scale)
309
+ x = x[:, :, 2:-2, 2:-2]
310
+ x = x.view(b, t, c, x.size(2), x.size(3))
311
+ if squeeze_flag:
312
+ x = x.squeeze(0)
313
+ return x
custom_nodes/ComfyUI-ReActor/r_basicsr/data/degradations.py ADDED
@@ -0,0 +1,768 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import random
5
+ import torch
6
+ from scipy import special
7
+ from scipy.stats import multivariate_normal
8
+ try:
9
+ from torchvision.transforms.functional_tensor import rgb_to_grayscale
10
+ except:
11
+ from torchvision.transforms.functional import rgb_to_grayscale
12
+
13
+ # -------------------------------------------------------------------- #
14
+ # --------------------------- blur kernels --------------------------- #
15
+ # -------------------------------------------------------------------- #
16
+
17
+
18
+ # --------------------------- util functions --------------------------- #
19
+ def sigma_matrix2(sig_x, sig_y, theta):
20
+ """Calculate the rotated sigma matrix (two dimensional matrix).
21
+
22
+ Args:
23
+ sig_x (float):
24
+ sig_y (float):
25
+ theta (float): Radian measurement.
26
+
27
+ Returns:
28
+ ndarray: Rotated sigma matrix.
29
+ """
30
+ d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
31
+ u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
32
+ return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
33
+
34
+
35
+ def mesh_grid(kernel_size):
36
+ """Generate the mesh grid, centering at zero.
37
+
38
+ Args:
39
+ kernel_size (int):
40
+
41
+ Returns:
42
+ xy (ndarray): with the shape (kernel_size, kernel_size, 2)
43
+ xx (ndarray): with the shape (kernel_size, kernel_size)
44
+ yy (ndarray): with the shape (kernel_size, kernel_size)
45
+ """
46
+ ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
47
+ xx, yy = np.meshgrid(ax, ax)
48
+ xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,
49
+ 1))).reshape(kernel_size, kernel_size, 2)
50
+ return xy, xx, yy
51
+
52
+
53
+ def pdf2(sigma_matrix, grid):
54
+ """Calculate PDF of the bivariate Gaussian distribution.
55
+
56
+ Args:
57
+ sigma_matrix (ndarray): with the shape (2, 2)
58
+ grid (ndarray): generated by :func:`mesh_grid`,
59
+ with the shape (K, K, 2), K is the kernel size.
60
+
61
+ Returns:
62
+ kernel (ndarrray): un-normalized kernel.
63
+ """
64
+ inverse_sigma = np.linalg.inv(sigma_matrix)
65
+ kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
66
+ return kernel
67
+
68
+
69
+ def cdf2(d_matrix, grid):
70
+ """Calculate the CDF of the standard bivariate Gaussian distribution.
71
+ Used in skewed Gaussian distribution.
72
+
73
+ Args:
74
+ d_matrix (ndarrasy): skew matrix.
75
+ grid (ndarray): generated by :func:`mesh_grid`,
76
+ with the shape (K, K, 2), K is the kernel size.
77
+
78
+ Returns:
79
+ cdf (ndarray): skewed cdf.
80
+ """
81
+ rv = multivariate_normal([0, 0], [[1, 0], [0, 1]])
82
+ grid = np.dot(grid, d_matrix)
83
+ cdf = rv.cdf(grid)
84
+ return cdf
85
+
86
+
87
+ def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
88
+ """Generate a bivariate isotropic or anisotropic Gaussian kernel.
89
+
90
+ In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
91
+
92
+ Args:
93
+ kernel_size (int):
94
+ sig_x (float):
95
+ sig_y (float):
96
+ theta (float): Radian measurement.
97
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
98
+ with the shape (K, K, 2), K is the kernel size. Default: None
99
+ isotropic (bool):
100
+
101
+ Returns:
102
+ kernel (ndarray): normalized kernel.
103
+ """
104
+ if grid is None:
105
+ grid, _, _ = mesh_grid(kernel_size)
106
+ if isotropic:
107
+ sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
108
+ else:
109
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
110
+ kernel = pdf2(sigma_matrix, grid)
111
+ kernel = kernel / np.sum(kernel)
112
+ return kernel
113
+
114
+
115
+ def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
116
+ """Generate a bivariate generalized Gaussian kernel.
117
+ Described in `Parameter Estimation For Multivariate Generalized
118
+ Gaussian Distributions`_
119
+ by Pascal et. al (2013).
120
+
121
+ In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
122
+
123
+ Args:
124
+ kernel_size (int):
125
+ sig_x (float):
126
+ sig_y (float):
127
+ theta (float): Radian measurement.
128
+ beta (float): shape parameter, beta = 1 is the normal distribution.
129
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
130
+ with the shape (K, K, 2), K is the kernel size. Default: None
131
+
132
+ Returns:
133
+ kernel (ndarray): normalized kernel.
134
+
135
+ .. _Parameter Estimation For Multivariate Generalized Gaussian
136
+ Distributions: https://arxiv.org/abs/1302.6498
137
+ """
138
+ if grid is None:
139
+ grid, _, _ = mesh_grid(kernel_size)
140
+ if isotropic:
141
+ sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
142
+ else:
143
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
144
+ inverse_sigma = np.linalg.inv(sigma_matrix)
145
+ kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
146
+ kernel = kernel / np.sum(kernel)
147
+ return kernel
148
+
149
+
150
+ def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True):
151
+ """Generate a plateau-like anisotropic kernel.
152
+ 1 / (1+x^(beta))
153
+
154
+ Ref: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution
155
+
156
+ In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.
157
+
158
+ Args:
159
+ kernel_size (int):
160
+ sig_x (float):
161
+ sig_y (float):
162
+ theta (float): Radian measurement.
163
+ beta (float): shape parameter, beta = 1 is the normal distribution.
164
+ grid (ndarray, optional): generated by :func:`mesh_grid`,
165
+ with the shape (K, K, 2), K is the kernel size. Default: None
166
+
167
+ Returns:
168
+ kernel (ndarray): normalized kernel.
169
+ """
170
+ if grid is None:
171
+ grid, _, _ = mesh_grid(kernel_size)
172
+ if isotropic:
173
+ sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
174
+ else:
175
+ sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
176
+ inverse_sigma = np.linalg.inv(sigma_matrix)
177
+ kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1)
178
+ kernel = kernel / np.sum(kernel)
179
+ return kernel
180
+
181
+
182
+ def random_bivariate_Gaussian(kernel_size,
183
+ sigma_x_range,
184
+ sigma_y_range,
185
+ rotation_range,
186
+ noise_range=None,
187
+ isotropic=True):
188
+ """Randomly generate bivariate isotropic or anisotropic Gaussian kernels.
189
+
190
+ In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
191
+
192
+ Args:
193
+ kernel_size (int):
194
+ sigma_x_range (tuple): [0.6, 5]
195
+ sigma_y_range (tuple): [0.6, 5]
196
+ rotation range (tuple): [-math.pi, math.pi]
197
+ noise_range(tuple, optional): multiplicative kernel noise,
198
+ [0.75, 1.25]. Default: None
199
+
200
+ Returns:
201
+ kernel (ndarray):
202
+ """
203
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
204
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
205
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
206
+ if isotropic is False:
207
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
208
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
209
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
210
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
211
+ else:
212
+ sigma_y = sigma_x
213
+ rotation = 0
214
+
215
+ kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic)
216
+
217
+ # add multiplicative noise
218
+ if noise_range is not None:
219
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
220
+ noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
221
+ kernel = kernel * noise
222
+ kernel = kernel / np.sum(kernel)
223
+ return kernel
224
+
225
+
226
+ def random_bivariate_generalized_Gaussian(kernel_size,
227
+ sigma_x_range,
228
+ sigma_y_range,
229
+ rotation_range,
230
+ beta_range,
231
+ noise_range=None,
232
+ isotropic=True):
233
+ """Randomly generate bivariate generalized Gaussian kernels.
234
+
235
+ In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
236
+
237
+ Args:
238
+ kernel_size (int):
239
+ sigma_x_range (tuple): [0.6, 5]
240
+ sigma_y_range (tuple): [0.6, 5]
241
+ rotation range (tuple): [-math.pi, math.pi]
242
+ beta_range (tuple): [0.5, 8]
243
+ noise_range(tuple, optional): multiplicative kernel noise,
244
+ [0.75, 1.25]. Default: None
245
+
246
+ Returns:
247
+ kernel (ndarray):
248
+ """
249
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
250
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
251
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
252
+ if isotropic is False:
253
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
254
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
255
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
256
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
257
+ else:
258
+ sigma_y = sigma_x
259
+ rotation = 0
260
+
261
+ # assume beta_range[0] < 1 < beta_range[1]
262
+ if np.random.uniform() < 0.5:
263
+ beta = np.random.uniform(beta_range[0], 1)
264
+ else:
265
+ beta = np.random.uniform(1, beta_range[1])
266
+
267
+ kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
268
+
269
+ # add multiplicative noise
270
+ if noise_range is not None:
271
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
272
+ noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
273
+ kernel = kernel * noise
274
+ kernel = kernel / np.sum(kernel)
275
+ return kernel
276
+
277
+
278
+ def random_bivariate_plateau(kernel_size,
279
+ sigma_x_range,
280
+ sigma_y_range,
281
+ rotation_range,
282
+ beta_range,
283
+ noise_range=None,
284
+ isotropic=True):
285
+ """Randomly generate bivariate plateau kernels.
286
+
287
+ In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
288
+
289
+ Args:
290
+ kernel_size (int):
291
+ sigma_x_range (tuple): [0.6, 5]
292
+ sigma_y_range (tuple): [0.6, 5]
293
+ rotation range (tuple): [-math.pi/2, math.pi/2]
294
+ beta_range (tuple): [1, 4]
295
+ noise_range(tuple, optional): multiplicative kernel noise,
296
+ [0.75, 1.25]. Default: None
297
+
298
+ Returns:
299
+ kernel (ndarray):
300
+ """
301
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
302
+ assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.'
303
+ sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1])
304
+ if isotropic is False:
305
+ assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.'
306
+ assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.'
307
+ sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1])
308
+ rotation = np.random.uniform(rotation_range[0], rotation_range[1])
309
+ else:
310
+ sigma_y = sigma_x
311
+ rotation = 0
312
+
313
+ # TODO: this may be not proper
314
+ if np.random.uniform() < 0.5:
315
+ beta = np.random.uniform(beta_range[0], 1)
316
+ else:
317
+ beta = np.random.uniform(1, beta_range[1])
318
+
319
+ kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic)
320
+ # add multiplicative noise
321
+ if noise_range is not None:
322
+ assert noise_range[0] < noise_range[1], 'Wrong noise range.'
323
+ noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape)
324
+ kernel = kernel * noise
325
+ kernel = kernel / np.sum(kernel)
326
+
327
+ return kernel
328
+
329
+
330
+ def random_mixed_kernels(kernel_list,
331
+ kernel_prob,
332
+ kernel_size=21,
333
+ sigma_x_range=(0.6, 5),
334
+ sigma_y_range=(0.6, 5),
335
+ rotation_range=(-math.pi, math.pi),
336
+ betag_range=(0.5, 8),
337
+ betap_range=(0.5, 8),
338
+ noise_range=None):
339
+ """Randomly generate mixed kernels.
340
+
341
+ Args:
342
+ kernel_list (tuple): a list name of kernel types,
343
+ support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso',
344
+ 'plateau_aniso']
345
+ kernel_prob (tuple): corresponding kernel probability for each
346
+ kernel type
347
+ kernel_size (int):
348
+ sigma_x_range (tuple): [0.6, 5]
349
+ sigma_y_range (tuple): [0.6, 5]
350
+ rotation range (tuple): [-math.pi, math.pi]
351
+ beta_range (tuple): [0.5, 8]
352
+ noise_range(tuple, optional): multiplicative kernel noise,
353
+ [0.75, 1.25]. Default: None
354
+
355
+ Returns:
356
+ kernel (ndarray):
357
+ """
358
+ kernel_type = random.choices(kernel_list, kernel_prob)[0]
359
+ if kernel_type == 'iso':
360
+ kernel = random_bivariate_Gaussian(
361
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True)
362
+ elif kernel_type == 'aniso':
363
+ kernel = random_bivariate_Gaussian(
364
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False)
365
+ elif kernel_type == 'generalized_iso':
366
+ kernel = random_bivariate_generalized_Gaussian(
367
+ kernel_size,
368
+ sigma_x_range,
369
+ sigma_y_range,
370
+ rotation_range,
371
+ betag_range,
372
+ noise_range=noise_range,
373
+ isotropic=True)
374
+ elif kernel_type == 'generalized_aniso':
375
+ kernel = random_bivariate_generalized_Gaussian(
376
+ kernel_size,
377
+ sigma_x_range,
378
+ sigma_y_range,
379
+ rotation_range,
380
+ betag_range,
381
+ noise_range=noise_range,
382
+ isotropic=False)
383
+ elif kernel_type == 'plateau_iso':
384
+ kernel = random_bivariate_plateau(
385
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True)
386
+ elif kernel_type == 'plateau_aniso':
387
+ kernel = random_bivariate_plateau(
388
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False)
389
+ return kernel
390
+
391
+
392
+ np.seterr(divide='ignore', invalid='ignore')
393
+
394
+
395
+ def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
396
+ """2D sinc filter, ref: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter
397
+
398
+ Args:
399
+ cutoff (float): cutoff frequency in radians (pi is max)
400
+ kernel_size (int): horizontal and vertical size, must be odd.
401
+ pad_to (int): pad kernel size to desired size, must be odd or zero.
402
+ """
403
+ assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
404
+ kernel = np.fromfunction(
405
+ lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
406
+ (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt(
407
+ (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size])
408
+ kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi)
409
+ kernel = kernel / np.sum(kernel)
410
+ if pad_to > kernel_size:
411
+ pad_size = (pad_to - kernel_size) // 2
412
+ kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
413
+ return kernel
414
+
415
+
416
+ # ------------------------------------------------------------- #
417
+ # --------------------------- noise --------------------------- #
418
+ # ------------------------------------------------------------- #
419
+
420
+ # ----------------------- Gaussian Noise ----------------------- #
421
+
422
+
423
+ def generate_gaussian_noise(img, sigma=10, gray_noise=False):
424
+ """Generate Gaussian noise.
425
+
426
+ Args:
427
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
428
+ sigma (float): Noise scale (measured in range 255). Default: 10.
429
+
430
+ Returns:
431
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
432
+ float32.
433
+ """
434
+ if gray_noise:
435
+ noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
436
+ noise = np.expand_dims(noise, axis=2).repeat(3, axis=2)
437
+ else:
438
+ noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.
439
+ return noise
440
+
441
+
442
+ def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False):
443
+ """Add Gaussian noise.
444
+
445
+ Args:
446
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
447
+ sigma (float): Noise scale (measured in range 255). Default: 10.
448
+
449
+ Returns:
450
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
451
+ float32.
452
+ """
453
+ noise = generate_gaussian_noise(img, sigma, gray_noise)
454
+ out = img + noise
455
+ if clip and rounds:
456
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
457
+ elif clip:
458
+ out = np.clip(out, 0, 1)
459
+ elif rounds:
460
+ out = (out * 255.0).round() / 255.
461
+ return out
462
+
463
+
464
+ def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
465
+ """Add Gaussian noise (PyTorch version).
466
+
467
+ Args:
468
+ img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
469
+ scale (float | Tensor): Noise scale. Default: 1.0.
470
+
471
+ Returns:
472
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
473
+ float32.
474
+ """
475
+ b, _, h, w = img.size()
476
+ if not isinstance(sigma, (float, int)):
477
+ sigma = sigma.view(img.size(0), 1, 1, 1)
478
+ if isinstance(gray_noise, (float, int)):
479
+ cal_gray_noise = gray_noise > 0
480
+ else:
481
+ gray_noise = gray_noise.view(b, 1, 1, 1)
482
+ cal_gray_noise = torch.sum(gray_noise) > 0
483
+
484
+ if cal_gray_noise:
485
+ noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255.
486
+ noise_gray = noise_gray.view(b, 1, h, w)
487
+
488
+ # always calculate color noise
489
+ noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.
490
+
491
+ if cal_gray_noise:
492
+ noise = noise * (1 - gray_noise) + noise_gray * gray_noise
493
+ return noise
494
+
495
+
496
+ def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
497
+ """Add Gaussian noise (PyTorch version).
498
+
499
+ Args:
500
+ img (Tensor): Shape (b, c, h, w), range[0, 1], float32.
501
+ scale (float | Tensor): Noise scale. Default: 1.0.
502
+
503
+ Returns:
504
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
505
+ float32.
506
+ """
507
+ noise = generate_gaussian_noise_pt(img, sigma, gray_noise)
508
+ out = img + noise
509
+ if clip and rounds:
510
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
511
+ elif clip:
512
+ out = torch.clamp(out, 0, 1)
513
+ elif rounds:
514
+ out = (out * 255.0).round() / 255.
515
+ return out
516
+
517
+
518
+ # ----------------------- Random Gaussian Noise ----------------------- #
519
+ def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0):
520
+ sigma = np.random.uniform(sigma_range[0], sigma_range[1])
521
+ if np.random.uniform() < gray_prob:
522
+ gray_noise = True
523
+ else:
524
+ gray_noise = False
525
+ return generate_gaussian_noise(img, sigma, gray_noise)
526
+
527
+
528
+ def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
529
+ noise = random_generate_gaussian_noise(img, sigma_range, gray_prob)
530
+ out = img + noise
531
+ if clip and rounds:
532
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
533
+ elif clip:
534
+ out = np.clip(out, 0, 1)
535
+ elif rounds:
536
+ out = (out * 255.0).round() / 255.
537
+ return out
538
+
539
+
540
+ def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
541
+ sigma = torch.rand(
542
+ img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
543
+ gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
544
+ gray_noise = (gray_noise < gray_prob).float()
545
+ return generate_gaussian_noise_pt(img, sigma, gray_noise)
546
+
547
+
548
+ def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
549
+ noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
550
+ out = img + noise
551
+ if clip and rounds:
552
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
553
+ elif clip:
554
+ out = torch.clamp(out, 0, 1)
555
+ elif rounds:
556
+ out = (out * 255.0).round() / 255.
557
+ return out
558
+
559
+
560
+ # ----------------------- Poisson (Shot) Noise ----------------------- #
561
+
562
+
563
+ def generate_poisson_noise(img, scale=1.0, gray_noise=False):
564
+ """Generate poisson noise.
565
+
566
+ Ref: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219
567
+
568
+ Args:
569
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
570
+ scale (float): Noise scale. Default: 1.0.
571
+ gray_noise (bool): Whether generate gray noise. Default: False.
572
+
573
+ Returns:
574
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
575
+ float32.
576
+ """
577
+ if gray_noise:
578
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
579
+ # round and clip image for counting vals correctly
580
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
581
+ vals = len(np.unique(img))
582
+ vals = 2**np.ceil(np.log2(vals))
583
+ out = np.float32(np.random.poisson(img * vals) / float(vals))
584
+ noise = out - img
585
+ if gray_noise:
586
+ noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2)
587
+ return noise * scale
588
+
589
+
590
+ def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False):
591
+ """Add poisson noise.
592
+
593
+ Args:
594
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
595
+ scale (float): Noise scale. Default: 1.0.
596
+ gray_noise (bool): Whether generate gray noise. Default: False.
597
+
598
+ Returns:
599
+ (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1],
600
+ float32.
601
+ """
602
+ noise = generate_poisson_noise(img, scale, gray_noise)
603
+ out = img + noise
604
+ if clip and rounds:
605
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
606
+ elif clip:
607
+ out = np.clip(out, 0, 1)
608
+ elif rounds:
609
+ out = (out * 255.0).round() / 255.
610
+ return out
611
+
612
+
613
+ def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
614
+ """Generate a batch of poisson noise (PyTorch version)
615
+
616
+ Args:
617
+ img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
618
+ scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
619
+ Default: 1.0.
620
+ gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
621
+ 0 for False, 1 for True. Default: 0.
622
+
623
+ Returns:
624
+ (Tensor): Generated Poisson noise (not the noisy image), shape (b, c, h, w), float32.
626
+ """
627
+ b, _, h, w = img.size()
628
+ if isinstance(gray_noise, (float, int)):
629
+ cal_gray_noise = gray_noise > 0
630
+ else:
631
+ gray_noise = gray_noise.view(b, 1, 1, 1)
632
+ cal_gray_noise = torch.sum(gray_noise) > 0
633
+ if cal_gray_noise:
634
+ img_gray = rgb_to_grayscale(img, num_output_channels=1)
635
+ # round and clip image for counting vals correctly
636
+ img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
637
+ # use for-loop to get the unique values for each sample
638
+ vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
639
+ vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
640
+ vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
641
+ out = torch.poisson(img_gray * vals) / vals
642
+ noise_gray = out - img_gray
643
+ noise_gray = noise_gray.expand(b, 3, h, w)
644
+
645
+ # always calculate color noise
646
+ # round and clip image for counting vals correctly
647
+ img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
648
+ # use for-loop to get the unique values for each sample
649
+ vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
650
+ vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
651
+ vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
652
+ out = torch.poisson(img * vals) / vals
653
+ noise = out - img
654
+ if cal_gray_noise:
655
+ noise = noise * (1 - gray_noise) + noise_gray * gray_noise
656
+ if not isinstance(scale, (float, int)):
657
+ scale = scale.view(b, 1, 1, 1)
658
+ return noise * scale
659
+
660
+
661
+ def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
662
+ """Add poisson noise to a batch of images (PyTorch version).
663
+
664
+ Args:
665
+ img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32.
666
+ scale (float | Tensor): Noise scale. Number or Tensor with shape (b).
667
+ Default: 1.0.
668
+ gray_noise (float | Tensor): 0-1 number or Tensor with shape (b).
669
+ 0 for False, 1 for True. Default: 0.
670
+
671
+ Returns:
672
+ (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1],
673
+ float32.
674
+ """
675
+ noise = generate_poisson_noise_pt(img, scale, gray_noise)
676
+ out = img + noise
677
+ if clip and rounds:
678
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
679
+ elif clip:
680
+ out = torch.clamp(out, 0, 1)
681
+ elif rounds:
682
+ out = (out * 255.0).round() / 255.
683
+ return out
684
+
685
+
686
+ # ----------------------- Random Poisson (Shot) Noise ----------------------- #
687
+
688
+
689
+ def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0):
690
+ scale = np.random.uniform(scale_range[0], scale_range[1])
691
+ if np.random.uniform() < gray_prob:
692
+ gray_noise = True
693
+ else:
694
+ gray_noise = False
695
+ return generate_poisson_noise(img, scale, gray_noise)
696
+
697
+
698
+ def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
699
+ noise = random_generate_poisson_noise(img, scale_range, gray_prob)
700
+ out = img + noise
701
+ if clip and rounds:
702
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.
703
+ elif clip:
704
+ out = np.clip(out, 0, 1)
705
+ elif rounds:
706
+ out = (out * 255.0).round() / 255.
707
+ return out
708
+
709
+
710
+ def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
711
+ scale = torch.rand(
712
+ img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
713
+ gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
714
+ gray_noise = (gray_noise < gray_prob).float()
715
+ return generate_poisson_noise_pt(img, scale, gray_noise)
716
+
717
+
718
+ def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False):
719
+ noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
720
+ out = img + noise
721
+ if clip and rounds:
722
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
723
+ elif clip:
724
+ out = torch.clamp(out, 0, 1)
725
+ elif rounds:
726
+ out = (out * 255.0).round() / 255.
727
+ return out
728
+
729
+
730
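A similar hedged sketch for the batched Poisson (shot) noise helpers (shapes and the scale range are illustrative assumptions):

import torch
# assumed import path: from r_basicsr.data.degradations import random_add_poisson_noise_pt

imgs = torch.rand(2, 3, 32, 32)  # (b, c, h, w), float32, range [0, 1]
noisy = random_add_poisson_noise_pt(imgs, scale_range=(0.05, 3.0), gray_prob=0.5, clip=True, rounds=False)
# unlike Gaussian noise, the shot-noise strength follows the local pixel intensity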
+ # ------------------------------------------------------------------------ #
731
+ # --------------------------- JPEG compression --------------------------- #
732
+ # ------------------------------------------------------------------------ #
733
+
734
+
735
+ def add_jpg_compression(img, quality=90):
736
+ """Add JPG compression artifacts.
737
+
738
+ Args:
739
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
740
+ quality (float): JPG compression quality. 0 for lowest quality, 100 for
741
+ best quality. Default: 90.
742
+
743
+ Returns:
744
+ (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
745
+ float32.
746
+ """
747
+ img = np.clip(img, 0, 1)
748
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
749
+ _, encimg = cv2.imencode('.jpg', img * 255., encode_param)
750
+ img = np.float32(cv2.imdecode(encimg, 1)) / 255.
751
+ return img
752
+
753
+
754
+ def random_add_jpg_compression(img, quality_range=(90, 100)):
755
+ """Randomly add JPG compression artifacts.
756
+
757
+ Args:
758
+ img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32.
759
+ quality_range (tuple[float] | list[float]): JPG compression quality
760
+ range. 0 for lowest quality, 100 for best quality.
761
+ Default: (90, 100).
762
+
763
+ Returns:
764
+ (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1],
765
+ float32.
766
+ """
767
+ quality = np.random.uniform(quality_range[0], quality_range[1])
768
+ return add_jpg_compression(img, quality)
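To show how the NumPy-side helpers in this file compose, here is a small hand-rolled degradation chain; the synthetic image and all ranges are placeholder assumptions rather than the pipeline this commit configures:

import numpy as np
# assumed import path: from r_basicsr.data.degradations import (
#     random_add_gaussian_noise, random_add_poisson_noise, add_jpg_compression)

img = np.random.rand(128, 128, 3).astype(np.float32)  # stand-in for a BGR image in [0, 1]
img = random_add_gaussian_noise(img, sigma_range=(2, 25), gray_prob=0.4)
img = random_add_poisson_noise(img, scale_range=(0.05, 2.5), gray_prob=0.4)
img = add_jpg_compression(img, quality=60)             # integer quality for cv2.imencode
print(img.shape, img.dtype)                            # (128, 128, 3) float32, still in [0, 1]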
custom_nodes/ComfyUI-ReActor/r_basicsr/data/ffhq_dataset.py ADDED
@@ -0,0 +1,80 @@
1
+ import random
2
+ import time
3
+ from os import path as osp
4
+ from torch.utils import data as data
5
+ from torchvision.transforms.functional import normalize
6
+
7
+ from r_basicsr.data.transforms import augment
8
+ from r_basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
9
+ from r_basicsr.utils.registry import DATASET_REGISTRY
10
+
11
+
12
+ @DATASET_REGISTRY.register()
13
+ class FFHQDataset(data.Dataset):
14
+ """FFHQ dataset for StyleGAN.
15
+
16
+ Args:
17
+ opt (dict): Config for train datasets. It contains the following keys:
18
+ dataroot_gt (str): Data root path for gt.
19
+ io_backend (dict): IO backend type and other kwarg.
20
+ mean (list | tuple): Image mean.
21
+ std (list | tuple): Image std.
22
+ use_hflip (bool): Whether to horizontally flip.
23
+
24
+ """
25
+
26
+ def __init__(self, opt):
27
+ super(FFHQDataset, self).__init__()
28
+ self.opt = opt
29
+ # file client (io backend)
30
+ self.file_client = None
31
+ self.io_backend_opt = opt['io_backend']
32
+
33
+ self.gt_folder = opt['dataroot_gt']
34
+ self.mean = opt['mean']
35
+ self.std = opt['std']
36
+
37
+ if self.io_backend_opt['type'] == 'lmdb':
38
+ self.io_backend_opt['db_paths'] = self.gt_folder
39
+ if not self.gt_folder.endswith('.lmdb'):
40
+ raise ValueError("'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
41
+ with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
42
+ self.paths = [line.split('.')[0] for line in fin]
43
+ else:
44
+ # FFHQ has 70000 images in total
45
+ self.paths = [osp.join(self.gt_folder, f'{v:08d}.png') for v in range(70000)]
46
+
47
+ def __getitem__(self, index):
48
+ if self.file_client is None:
49
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
50
+
51
+ # load gt image
52
+ gt_path = self.paths[index]
53
+ # avoid errors caused by high latency in reading files
54
+ retry = 3
55
+ while retry > 0:
56
+ try:
57
+ img_bytes = self.file_client.get(gt_path)
58
+ except Exception as e:
59
+ logger = get_root_logger()
60
+ logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
61
+ # change another file to read
62
+ index = random.randint(0, self.__len__() - 1)  # randint is inclusive on both ends
63
+ gt_path = self.paths[index]
64
+ time.sleep(1) # sleep 1s for occasional server congestion
65
+ else:
66
+ break
67
+ finally:
68
+ retry -= 1
69
+ img_gt = imfrombytes(img_bytes, float32=True)
70
+
71
+ # random horizontal flip
72
+ img_gt = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False)
73
+ # BGR to RGB, HWC to CHW, numpy to tensor
74
+ img_gt = img2tensor(img_gt, bgr2rgb=True, float32=True)
75
+ # normalize
76
+ normalize(img_gt, self.mean, self.std, inplace=True)
77
+ return {'gt': img_gt, 'gt_path': gt_path}
78
+
79
+ def __len__(self):
80
+ return len(self.paths)
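A sketch of the option dict FFHQDataset expects, mirroring its docstring; the data root is a placeholder and the PNG files must actually exist before __getitem__ will succeed:

opt = {
    'dataroot_gt': '/path/to/ffhq',   # placeholder; the disk backend expects 00000000.png ... 00069999.png
    'io_backend': {'type': 'disk'},
    'mean': [0.5, 0.5, 0.5],
    'std': [0.5, 0.5, 0.5],
    'use_hflip': True,
}
dataset = FFHQDataset(opt)
print(len(dataset))                   # 70000 paths are generated for the disk backend
# sample = dataset[0]                 # {'gt': CHW float tensor, 'gt_path': str}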
custom_nodes/ComfyUI-ReActor/r_basicsr/data/paired_image_dataset.py ADDED
@@ -0,0 +1,108 @@
1
+ from torch.utils import data as data
2
+ from torchvision.transforms.functional import normalize
3
+
4
+ from r_basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
5
+ from r_basicsr.data.transforms import augment, paired_random_crop
6
+ from r_basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor
7
+ from r_basicsr.utils.registry import DATASET_REGISTRY
8
+
9
+
10
+ @DATASET_REGISTRY.register()
11
+ class PairedImageDataset(data.Dataset):
12
+ """Paired image dataset for image restoration.
13
+
14
+ Read LQ (low quality, e.g. low resolution, blurry, noisy) and GT image pairs.
15
+
16
+ There are three modes:
17
+ 1. 'lmdb': Use lmdb files.
18
+ If opt['io_backend'] == lmdb.
19
+ 2. 'meta_info_file': Use meta information file to generate paths.
20
+ If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
21
+ 3. 'folder': Scan folders to generate paths.
22
+ The rest.
23
+
24
+ Args:
25
+ opt (dict): Config for train datasets. It contains the following keys:
26
+ dataroot_gt (str): Data root path for gt.
27
+ dataroot_lq (str): Data root path for lq.
28
+ meta_info_file (str): Path for meta information file.
29
+ io_backend (dict): IO backend type and other kwarg.
30
+ filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
31
+ Default: '{}'.
32
+ gt_size (int): Cropped patch size for gt patches.
33
+ use_hflip (bool): Use horizontal flips.
34
+ use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
35
+
36
+ scale (int): Scale factor, which will be added automatically.
37
+ phase (str): 'train' or 'val'.
38
+ """
39
+
40
+ def __init__(self, opt):
41
+ super(PairedImageDataset, self).__init__()
42
+ self.opt = opt
43
+ # file client (io backend)
44
+ self.file_client = None
45
+ self.io_backend_opt = opt['io_backend']
46
+ self.mean = opt['mean'] if 'mean' in opt else None
47
+ self.std = opt['std'] if 'std' in opt else None
48
+
49
+ self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
50
+ if 'filename_tmpl' in opt:
51
+ self.filename_tmpl = opt['filename_tmpl']
52
+ else:
53
+ self.filename_tmpl = '{}'
54
+
55
+ if self.io_backend_opt['type'] == 'lmdb':
56
+ self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
57
+ self.io_backend_opt['client_keys'] = ['lq', 'gt']
58
+ self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
59
+ elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
60
+ self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
61
+ self.opt['meta_info_file'], self.filename_tmpl)
62
+ else:
63
+ self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
64
+
65
+ def __getitem__(self, index):
66
+ if self.file_client is None:
67
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
68
+
69
+ scale = self.opt['scale']
70
+
71
+ # Load gt and lq images. Dimension order: HWC; channel order: BGR;
72
+ # image range: [0, 1], float32.
73
+ gt_path = self.paths[index]['gt_path']
74
+ img_bytes = self.file_client.get(gt_path, 'gt')
75
+ img_gt = imfrombytes(img_bytes, float32=True)
76
+ lq_path = self.paths[index]['lq_path']
77
+ img_bytes = self.file_client.get(lq_path, 'lq')
78
+ img_lq = imfrombytes(img_bytes, float32=True)
79
+
80
+ # augmentation for training
81
+ if self.opt['phase'] == 'train':
82
+ gt_size = self.opt['gt_size']
83
+ # random crop
84
+ img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
85
+ # flip, rotation
86
+ img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
87
+
88
+ # color space transform
89
+ if 'color' in self.opt and self.opt['color'] == 'y':
90
+ img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None]
91
+ img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None]
92
+
93
+ # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets
94
+ # TODO: It is better to update the datasets, rather than force to crop
95
+ if self.opt['phase'] != 'train':
96
+ img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :]
97
+
98
+ # BGR to RGB, HWC to CHW, numpy to tensor
99
+ img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
100
+ # normalize
101
+ if self.mean is not None or self.std is not None:
102
+ normalize(img_lq, self.mean, self.std, inplace=True)
103
+ normalize(img_gt, self.mean, self.std, inplace=True)
104
+
105
+ return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path}
106
+
107
+ def __len__(self):
108
+ return len(self.paths)
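A hedged configuration sketch for PairedImageDataset; the paths, scale factor, and crop size are placeholders derived from the docstring keys, not values taken from this repository's configs:

from torch.utils.data import DataLoader

opt = {
    'dataroot_gt': '/path/to/HR',     # placeholder paths
    'dataroot_lq': '/path/to/LR_x4',
    'io_backend': {'type': 'disk'},
    'scale': 4,
    'phase': 'train',
    'gt_size': 128,                   # GT crop size; the LQ crop becomes gt_size // scale
    'use_hflip': True,
    'use_rot': True,
}
dataset = PairedImageDataset(opt)
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
# each batch: {'lq': (8, 3, 32, 32), 'gt': (8, 3, 128, 128), 'lq_path': [...], 'gt_path': [...]}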
custom_nodes/ComfyUI-ReActor/r_basicsr/data/prefetch_dataloader.py ADDED
@@ -0,0 +1,125 @@
1
+ import queue as Queue
2
+ import threading
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+
6
+
7
+ class PrefetchGenerator(threading.Thread):
8
+ """A general prefetch generator.
9
+
10
+ Ref:
11
+ https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
12
+
13
+ Args:
14
+ generator: Python generator.
15
+ num_prefetch_queue (int): Length of the prefetch queue.
16
+ """
17
+
18
+ def __init__(self, generator, num_prefetch_queue):
19
+ threading.Thread.__init__(self)
20
+ self.queue = Queue.Queue(num_prefetch_queue)
21
+ self.generator = generator
22
+ self.daemon = True
23
+ self.start()
24
+
25
+ def run(self):
26
+ for item in self.generator:
27
+ self.queue.put(item)
28
+ self.queue.put(None)
29
+
30
+ def __next__(self):
31
+ next_item = self.queue.get()
32
+ if next_item is None:
33
+ raise StopIteration
34
+ return next_item
35
+
36
+ def __iter__(self):
37
+ return self
38
+
39
+
40
+ class PrefetchDataLoader(DataLoader):
41
+ """Prefetch version of dataloader.
42
+
43
+ Ref:
44
+ https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
45
+
46
+ TODO:
47
+ Need to test on single gpu and ddp (multi-gpu). There is a known issue in
48
+ ddp.
49
+
50
+ Args:
51
+ num_prefetch_queue (int): Length of the prefetch queue.
52
+ kwargs (dict): Other arguments for dataloader.
53
+ """
54
+
55
+ def __init__(self, num_prefetch_queue, **kwargs):
56
+ self.num_prefetch_queue = num_prefetch_queue
57
+ super(PrefetchDataLoader, self).__init__(**kwargs)
58
+
59
+ def __iter__(self):
60
+ return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
61
+
62
+
63
+ class CPUPrefetcher():
64
+ """CPU prefetcher.
65
+
66
+ Args:
67
+ loader: Dataloader.
68
+ """
69
+
70
+ def __init__(self, loader):
71
+ self.ori_loader = loader
72
+ self.loader = iter(loader)
73
+
74
+ def next(self):
75
+ try:
76
+ return next(self.loader)
77
+ except StopIteration:
78
+ return None
79
+
80
+ def reset(self):
81
+ self.loader = iter(self.ori_loader)
82
+
83
+
84
+ class CUDAPrefetcher():
85
+ """CUDA prefetcher.
86
+
87
+ Ref:
88
+ https://github.com/NVIDIA/apex/issues/304#
89
+
90
+ It may consume more GPU memory.
91
+
92
+ Args:
93
+ loader: Dataloader.
94
+ opt (dict): Options.
95
+ """
96
+
97
+ def __init__(self, loader, opt):
98
+ self.ori_loader = loader
99
+ self.loader = iter(loader)
100
+ self.opt = opt
101
+ self.stream = torch.cuda.Stream()
102
+ self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
103
+ self.preload()
104
+
105
+ def preload(self):
106
+ try:
107
+ self.batch = next(self.loader) # self.batch is a dict
108
+ except StopIteration:
109
+ self.batch = None
110
+ return None
111
+ # put tensors to gpu
112
+ with torch.cuda.stream(self.stream):
113
+ for k, v in self.batch.items():
114
+ if torch.is_tensor(v):
115
+ self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
116
+
117
+ def next(self):
118
+ torch.cuda.current_stream().wait_stream(self.stream)
119
+ batch = self.batch
120
+ self.preload()
121
+ return batch
122
+
123
+ def reset(self):
124
+ self.loader = iter(self.ori_loader)
125
+ self.preload()
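A minimal sketch of how the prefetchers above are typically driven in a training loop; the loader and opt dict are assumptions, and CUDAPrefetcher requires a CUDA device:

prefetcher = CUDAPrefetcher(loader, opt={'num_gpu': 1})

prefetcher.reset()                    # start of an epoch: rewind the underlying dataloader
batch = prefetcher.next()
while batch is not None:
    gt = batch['gt']                  # tensors were already copied to the GPU asynchronously
    # ... forward / backward / optimizer step ...
    batch = prefetcher.next()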
custom_nodes/ComfyUI-ReActor/r_basicsr/data/realesrgan_dataset.py ADDED
@@ -0,0 +1,193 @@
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import os
5
+ import os.path as osp
6
+ import random
7
+ import time
8
+ import torch
9
+ from torch.utils import data as data
10
+
11
+ from r_basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
12
+ from r_basicsr.data.transforms import augment
13
+ from r_basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
14
+ from r_basicsr.utils.registry import DATASET_REGISTRY
15
+
16
+
17
+ @DATASET_REGISTRY.register(suffix='basicsr')
18
+ class RealESRGANDataset(data.Dataset):
19
+ """Dataset used for Real-ESRGAN model:
20
+ Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
21
+
22
+ It loads gt (Ground-Truth) images and augments them.
+ It also generates blur kernels and sinc kernels for producing low-quality images.
+ Note that the low-quality images are processed as tensors on GPUs for faster processing.
25
+
26
+ Args:
27
+ opt (dict): Config for train datasets. It contains the following keys:
28
+ dataroot_gt (str): Data root path for gt.
29
+ meta_info (str): Path for meta information file.
30
+ io_backend (dict): IO backend type and other kwarg.
31
+ use_hflip (bool): Use horizontal flips.
32
+ use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
33
+ Please see more options in the code.
34
+ """
35
+
36
+ def __init__(self, opt):
37
+ super(RealESRGANDataset, self).__init__()
38
+ self.opt = opt
39
+ self.file_client = None
40
+ self.io_backend_opt = opt['io_backend']
41
+ self.gt_folder = opt['dataroot_gt']
42
+
43
+ # file client (lmdb io backend)
44
+ if self.io_backend_opt['type'] == 'lmdb':
45
+ self.io_backend_opt['db_paths'] = [self.gt_folder]
46
+ self.io_backend_opt['client_keys'] = ['gt']
47
+ if not self.gt_folder.endswith('.lmdb'):
48
+ raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}")
49
+ with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin:
50
+ self.paths = [line.split('.')[0] for line in fin]
51
+ else:
52
+ # disk backend with meta_info
53
+ # Each line in the meta_info describes the relative path to an image
54
+ with open(self.opt['meta_info']) as fin:
55
+ paths = [line.strip().split(' ')[0] for line in fin]
56
+ self.paths = [os.path.join(self.gt_folder, v) for v in paths]
57
+
58
+ # blur settings for the first degradation
59
+ self.blur_kernel_size = opt['blur_kernel_size']
60
+ self.kernel_list = opt['kernel_list']
61
+ self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability
62
+ self.blur_sigma = opt['blur_sigma']
63
+ self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels
64
+ self.betap_range = opt['betap_range'] # betap used in plateau blur kernels
65
+ self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters
66
+
67
+ # blur settings for the second degradation
68
+ self.blur_kernel_size2 = opt['blur_kernel_size2']
69
+ self.kernel_list2 = opt['kernel_list2']
70
+ self.kernel_prob2 = opt['kernel_prob2']
71
+ self.blur_sigma2 = opt['blur_sigma2']
72
+ self.betag_range2 = opt['betag_range2']
73
+ self.betap_range2 = opt['betap_range2']
74
+ self.sinc_prob2 = opt['sinc_prob2']
75
+
76
+ # a final sinc filter
77
+ self.final_sinc_prob = opt['final_sinc_prob']
78
+
79
+ self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
80
+ # TODO: kernel range is now hard-coded, should be in the configure file
81
+ self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
82
+ self.pulse_tensor[10, 10] = 1
83
+
84
+ def __getitem__(self, index):
85
+ if self.file_client is None:
86
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
87
+
88
+ # -------------------------------- Load gt images -------------------------------- #
89
+ # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
90
+ gt_path = self.paths[index]
91
+ # avoid errors caused by high latency in reading files
92
+ retry = 3
93
+ while retry > 0:
94
+ try:
95
+ img_bytes = self.file_client.get(gt_path, 'gt')
96
+ except (IOError, OSError) as e:
97
+ logger = get_root_logger()
98
+ logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
99
+ # change another file to read
100
+ index = random.randint(0, self.__len__() - 1)  # randint is inclusive on both ends
101
+ gt_path = self.paths[index]
102
+ time.sleep(1) # sleep 1s for occasional server congestion
103
+ else:
104
+ break
105
+ finally:
106
+ retry -= 1
107
+ img_gt = imfrombytes(img_bytes, float32=True)
108
+
109
+ # -------------------- Do augmentation for training: flip, rotation -------------------- #
110
+ img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])
111
+
112
+ # crop or pad to 400
113
+ # TODO: 400 is hard-coded. You may change it accordingly
114
+ h, w = img_gt.shape[0:2]
115
+ crop_pad_size = 400
116
+ # pad
117
+ if h < crop_pad_size or w < crop_pad_size:
118
+ pad_h = max(0, crop_pad_size - h)
119
+ pad_w = max(0, crop_pad_size - w)
120
+ img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
121
+ # crop
122
+ if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size:
123
+ h, w = img_gt.shape[0:2]
124
+ # randomly choose top and left coordinates
125
+ top = random.randint(0, h - crop_pad_size)
126
+ left = random.randint(0, w - crop_pad_size)
127
+ img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]
128
+
129
+ # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
130
+ kernel_size = random.choice(self.kernel_range)
131
+ if np.random.uniform() < self.opt['sinc_prob']:
132
+ # this sinc filter setting is for kernels ranging from [7, 21]
133
+ if kernel_size < 13:
134
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
135
+ else:
136
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
137
+ kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
138
+ else:
139
+ kernel = random_mixed_kernels(
140
+ self.kernel_list,
141
+ self.kernel_prob,
142
+ kernel_size,
143
+ self.blur_sigma,
144
+ self.blur_sigma, [-math.pi, math.pi],
145
+ self.betag_range,
146
+ self.betap_range,
147
+ noise_range=None)
148
+ # pad kernel
149
+ pad_size = (21 - kernel_size) // 2
150
+ kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
151
+
152
+ # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
153
+ kernel_size = random.choice(self.kernel_range)
154
+ if np.random.uniform() < self.opt['sinc_prob2']:
155
+ if kernel_size < 13:
156
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
157
+ else:
158
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
159
+ kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
160
+ else:
161
+ kernel2 = random_mixed_kernels(
162
+ self.kernel_list2,
163
+ self.kernel_prob2,
164
+ kernel_size,
165
+ self.blur_sigma2,
166
+ self.blur_sigma2, [-math.pi, math.pi],
167
+ self.betag_range2,
168
+ self.betap_range2,
169
+ noise_range=None)
170
+
171
+ # pad kernel
172
+ pad_size = (21 - kernel_size) // 2
173
+ kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
174
+
175
+ # ------------------------------------- the final sinc kernel ------------------------------------- #
176
+ if np.random.uniform() < self.opt['final_sinc_prob']:
177
+ kernel_size = random.choice(self.kernel_range)
178
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
179
+ sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
180
+ sinc_kernel = torch.FloatTensor(sinc_kernel)
181
+ else:
182
+ sinc_kernel = self.pulse_tensor
183
+
184
+ # BGR to RGB, HWC to CHW, numpy to tensor
185
+ img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0]
186
+ kernel = torch.FloatTensor(kernel)
187
+ kernel2 = torch.FloatTensor(kernel2)
188
+
189
+ return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}
190
+ return return_d
191
+
192
+ def __len__(self):
193
+ return len(self.paths)
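For orientation, a sketch of the option dict RealESRGANDataset reads; the values loosely follow the published Real-ESRGAN training settings and should be treated as placeholders rather than what this commit ships:

opt = {
    'dataroot_gt': '/path/to/gt_images',           # placeholder
    'meta_info': '/path/to/meta_info.txt',         # one relative image path per line
    'io_backend': {'type': 'disk'},
    'use_hflip': True, 'use_rot': False,
    # first-stage blur kernels
    'blur_kernel_size': 21,
    'kernel_list': ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
    'kernel_prob': [0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
    'blur_sigma': [0.2, 3], 'betag_range': [0.5, 4], 'betap_range': [1, 2], 'sinc_prob': 0.1,
    # second-stage blur kernels
    'blur_kernel_size2': 21,
    'kernel_list2': ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
    'kernel_prob2': [0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
    'blur_sigma2': [0.2, 1.5], 'betag_range2': [0.5, 4], 'betap_range2': [1, 2], 'sinc_prob2': 0.1,
    'final_sinc_prob': 0.8,
}
dataset = RealESRGANDataset(opt)
# each item: {'gt': CHW tensor, 'kernel1': 21x21, 'kernel2': 21x21, 'sinc_kernel': 21x21, 'gt_path': str}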