lopho commited on
Commit
149cc2d
·
1 Parent(s): cc65356

this isn't very nice.

Browse files
makeavid_sd/LICENSE ADDED
@@ -0,0 +1,661 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ GNU AFFERO GENERAL PUBLIC LICENSE
2
+ Version 3, 19 November 2007
3
+
4
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5
+ Everyone is permitted to copy and distribute verbatim copies
6
+ of this license document, but changing it is not allowed.
7
+
8
+ Preamble
9
+
10
+ The GNU Affero General Public License is a free, copyleft license for
11
+ software and other kinds of works, specifically designed to ensure
12
+ cooperation with the community in the case of network server software.
13
+
14
+ The licenses for most software and other practical works are designed
15
+ to take away your freedom to share and change the works. By contrast,
16
+ our General Public Licenses are intended to guarantee your freedom to
17
+ share and change all versions of a program--to make sure it remains free
18
+ software for all its users.
19
+
20
+ When we speak of free software, we are referring to freedom, not
21
+ price. Our General Public Licenses are designed to make sure that you
22
+ have the freedom to distribute copies of free software (and charge for
23
+ them if you wish), that you receive source code or can get it if you
24
+ want it, that you can change the software or use pieces of it in new
25
+ free programs, and that you know you can do these things.
26
+
27
+ Developers that use our General Public Licenses protect your rights
28
+ with two steps: (1) assert copyright on the software, and (2) offer
29
+ you this License which gives you legal permission to copy, distribute
30
+ and/or modify the software.
31
+
32
+ A secondary benefit of defending all users' freedom is that
33
+ improvements made in alternate versions of the program, if they
34
+ receive widespread use, become available for other developers to
35
+ incorporate. Many developers of free software are heartened and
36
+ encouraged by the resulting cooperation. However, in the case of
37
+ software used on network servers, this result may fail to come about.
38
+ The GNU General Public License permits making a modified version and
39
+ letting the public access it on a server without ever releasing its
40
+ source code to the public.
41
+
42
+ The GNU Affero General Public License is designed specifically to
43
+ ensure that, in such cases, the modified source code becomes available
44
+ to the community. It requires the operator of a network server to
45
+ provide the source code of the modified version running there to the
46
+ users of that server. Therefore, public use of a modified version, on
47
+ a publicly accessible server, gives the public access to the source
48
+ code of the modified version.
49
+
50
+ An older license, called the Affero General Public License and
51
+ published by Affero, was designed to accomplish similar goals. This is
52
+ a different license, not a version of the Affero GPL, but Affero has
53
+ released a new version of the Affero GPL which permits relicensing under
54
+ this license.
55
+
56
+ The precise terms and conditions for copying, distribution and
57
+ modification follow.
58
+
59
+ TERMS AND CONDITIONS
60
+
61
+ 0. Definitions.
62
+
63
+ "This License" refers to version 3 of the GNU Affero General Public License.
64
+
65
+ "Copyright" also means copyright-like laws that apply to other kinds of
66
+ works, such as semiconductor masks.
67
+
68
+ "The Program" refers to any copyrightable work licensed under this
69
+ License. Each licensee is addressed as "you". "Licensees" and
70
+ "recipients" may be individuals or organizations.
71
+
72
+ To "modify" a work means to copy from or adapt all or part of the work
73
+ in a fashion requiring copyright permission, other than the making of an
74
+ exact copy. The resulting work is called a "modified version" of the
75
+ earlier work or a work "based on" the earlier work.
76
+
77
+ A "covered work" means either the unmodified Program or a work based
78
+ on the Program.
79
+
80
+ To "propagate" a work means to do anything with it that, without
81
+ permission, would make you directly or secondarily liable for
82
+ infringement under applicable copyright law, except executing it on a
83
+ computer or modifying a private copy. Propagation includes copying,
84
+ distribution (with or without modification), making available to the
85
+ public, and in some countries other activities as well.
86
+
87
+ To "convey" a work means any kind of propagation that enables other
88
+ parties to make or receive copies. Mere interaction with a user through
89
+ a computer network, with no transfer of a copy, is not conveying.
90
+
91
+ An interactive user interface displays "Appropriate Legal Notices"
92
+ to the extent that it includes a convenient and prominently visible
93
+ feature that (1) displays an appropriate copyright notice, and (2)
94
+ tells the user that there is no warranty for the work (except to the
95
+ extent that warranties are provided), that licensees may convey the
96
+ work under this License, and how to view a copy of this License. If
97
+ the interface presents a list of user commands or options, such as a
98
+ menu, a prominent item in the list meets this criterion.
99
+
100
+ 1. Source Code.
101
+
102
+ The "source code" for a work means the preferred form of the work
103
+ for making modifications to it. "Object code" means any non-source
104
+ form of a work.
105
+
106
+ A "Standard Interface" means an interface that either is an official
107
+ standard defined by a recognized standards body, or, in the case of
108
+ interfaces specified for a particular programming language, one that
109
+ is widely used among developers working in that language.
110
+
111
+ The "System Libraries" of an executable work include anything, other
112
+ than the work as a whole, that (a) is included in the normal form of
113
+ packaging a Major Component, but which is not part of that Major
114
+ Component, and (b) serves only to enable use of the work with that
115
+ Major Component, or to implement a Standard Interface for which an
116
+ implementation is available to the public in source code form. A
117
+ "Major Component", in this context, means a major essential component
118
+ (kernel, window system, and so on) of the specific operating system
119
+ (if any) on which the executable work runs, or a compiler used to
120
+ produce the work, or an object code interpreter used to run it.
121
+
122
+ The "Corresponding Source" for a work in object code form means all
123
+ the source code needed to generate, install, and (for an executable
124
+ work) run the object code and to modify the work, including scripts to
125
+ control those activities. However, it does not include the work's
126
+ System Libraries, or general-purpose tools or generally available free
127
+ programs which are used unmodified in performing those activities but
128
+ which are not part of the work. For example, Corresponding Source
129
+ includes interface definition files associated with source files for
130
+ the work, and the source code for shared libraries and dynamically
131
+ linked subprograms that the work is specifically designed to require,
132
+ such as by intimate data communication or control flow between those
133
+ subprograms and other parts of the work.
134
+
135
+ The Corresponding Source need not include anything that users
136
+ can regenerate automatically from other parts of the Corresponding
137
+ Source.
138
+
139
+ The Corresponding Source for a work in source code form is that
140
+ same work.
141
+
142
+ 2. Basic Permissions.
143
+
144
+ All rights granted under this License are granted for the term of
145
+ copyright on the Program, and are irrevocable provided the stated
146
+ conditions are met. This License explicitly affirms your unlimited
147
+ permission to run the unmodified Program. The output from running a
148
+ covered work is covered by this License only if the output, given its
149
+ content, constitutes a covered work. This License acknowledges your
150
+ rights of fair use or other equivalent, as provided by copyright law.
151
+
152
+ You may make, run and propagate covered works that you do not
153
+ convey, without conditions so long as your license otherwise remains
154
+ in force. You may convey covered works to others for the sole purpose
155
+ of having them make modifications exclusively for you, or provide you
156
+ with facilities for running those works, provided that you comply with
157
+ the terms of this License in conveying all material for which you do
158
+ not control copyright. Those thus making or running the covered works
159
+ for you must do so exclusively on your behalf, under your direction
160
+ and control, on terms that prohibit them from making any copies of
161
+ your copyrighted material outside their relationship with you.
162
+
163
+ Conveying under any other circumstances is permitted solely under
164
+ the conditions stated below. Sublicensing is not allowed; section 10
165
+ makes it unnecessary.
166
+
167
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
168
+
169
+ No covered work shall be deemed part of an effective technological
170
+ measure under any applicable law fulfilling obligations under article
171
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
172
+ similar laws prohibiting or restricting circumvention of such
173
+ measures.
174
+
175
+ When you convey a covered work, you waive any legal power to forbid
176
+ circumvention of technological measures to the extent such circumvention
177
+ is effected by exercising rights under this License with respect to
178
+ the covered work, and you disclaim any intention to limit operation or
179
+ modification of the work as a means of enforcing, against the work's
180
+ users, your or third parties' legal rights to forbid circumvention of
181
+ technological measures.
182
+
183
+ 4. Conveying Verbatim Copies.
184
+
185
+ You may convey verbatim copies of the Program's source code as you
186
+ receive it, in any medium, provided that you conspicuously and
187
+ appropriately publish on each copy an appropriate copyright notice;
188
+ keep intact all notices stating that this License and any
189
+ non-permissive terms added in accord with section 7 apply to the code;
190
+ keep intact all notices of the absence of any warranty; and give all
191
+ recipients a copy of this License along with the Program.
192
+
193
+ You may charge any price or no price for each copy that you convey,
194
+ and you may offer support or warranty protection for a fee.
195
+
196
+ 5. Conveying Modified Source Versions.
197
+
198
+ You may convey a work based on the Program, or the modifications to
199
+ produce it from the Program, in the form of source code under the
200
+ terms of section 4, provided that you also meet all of these conditions:
201
+
202
+ a) The work must carry prominent notices stating that you modified
203
+ it, and giving a relevant date.
204
+
205
+ b) The work must carry prominent notices stating that it is
206
+ released under this License and any conditions added under section
207
+ 7. This requirement modifies the requirement in section 4 to
208
+ "keep intact all notices".
209
+
210
+ c) You must license the entire work, as a whole, under this
211
+ License to anyone who comes into possession of a copy. This
212
+ License will therefore apply, along with any applicable section 7
213
+ additional terms, to the whole of the work, and all its parts,
214
+ regardless of how they are packaged. This License gives no
215
+ permission to license the work in any other way, but it does not
216
+ invalidate such permission if you have separately received it.
217
+
218
+ d) If the work has interactive user interfaces, each must display
219
+ Appropriate Legal Notices; however, if the Program has interactive
220
+ interfaces that do not display Appropriate Legal Notices, your
221
+ work need not make them do so.
222
+
223
+ A compilation of a covered work with other separate and independent
224
+ works, which are not by their nature extensions of the covered work,
225
+ and which are not combined with it such as to form a larger program,
226
+ in or on a volume of a storage or distribution medium, is called an
227
+ "aggregate" if the compilation and its resulting copyright are not
228
+ used to limit the access or legal rights of the compilation's users
229
+ beyond what the individual works permit. Inclusion of a covered work
230
+ in an aggregate does not cause this License to apply to the other
231
+ parts of the aggregate.
232
+
233
+ 6. Conveying Non-Source Forms.
234
+
235
+ You may convey a covered work in object code form under the terms
236
+ of sections 4 and 5, provided that you also convey the
237
+ machine-readable Corresponding Source under the terms of this License,
238
+ in one of these ways:
239
+
240
+ a) Convey the object code in, or embodied in, a physical product
241
+ (including a physical distribution medium), accompanied by the
242
+ Corresponding Source fixed on a durable physical medium
243
+ customarily used for software interchange.
244
+
245
+ b) Convey the object code in, or embodied in, a physical product
246
+ (including a physical distribution medium), accompanied by a
247
+ written offer, valid for at least three years and valid for as
248
+ long as you offer spare parts or customer support for that product
249
+ model, to give anyone who possesses the object code either (1) a
250
+ copy of the Corresponding Source for all the software in the
251
+ product that is covered by this License, on a durable physical
252
+ medium customarily used for software interchange, for a price no
253
+ more than your reasonable cost of physically performing this
254
+ conveying of source, or (2) access to copy the
255
+ Corresponding Source from a network server at no charge.
256
+
257
+ c) Convey individual copies of the object code with a copy of the
258
+ written offer to provide the Corresponding Source. This
259
+ alternative is allowed only occasionally and noncommercially, and
260
+ only if you received the object code with such an offer, in accord
261
+ with subsection 6b.
262
+
263
+ d) Convey the object code by offering access from a designated
264
+ place (gratis or for a charge), and offer equivalent access to the
265
+ Corresponding Source in the same way through the same place at no
266
+ further charge. You need not require recipients to copy the
267
+ Corresponding Source along with the object code. If the place to
268
+ copy the object code is a network server, the Corresponding Source
269
+ may be on a different server (operated by you or a third party)
270
+ that supports equivalent copying facilities, provided you maintain
271
+ clear directions next to the object code saying where to find the
272
+ Corresponding Source. Regardless of what server hosts the
273
+ Corresponding Source, you remain obligated to ensure that it is
274
+ available for as long as needed to satisfy these requirements.
275
+
276
+ e) Convey the object code using peer-to-peer transmission, provided
277
+ you inform other peers where the object code and Corresponding
278
+ Source of the work are being offered to the general public at no
279
+ charge under subsection 6d.
280
+
281
+ A separable portion of the object code, whose source code is excluded
282
+ from the Corresponding Source as a System Library, need not be
283
+ included in conveying the object code work.
284
+
285
+ A "User Product" is either (1) a "consumer product", which means any
286
+ tangible personal property which is normally used for personal, family,
287
+ or household purposes, or (2) anything designed or sold for incorporation
288
+ into a dwelling. In determining whether a product is a consumer product,
289
+ doubtful cases shall be resolved in favor of coverage. For a particular
290
+ product received by a particular user, "normally used" refers to a
291
+ typical or common use of that class of product, regardless of the status
292
+ of the particular user or of the way in which the particular user
293
+ actually uses, or expects or is expected to use, the product. A product
294
+ is a consumer product regardless of whether the product has substantial
295
+ commercial, industrial or non-consumer uses, unless such uses represent
296
+ the only significant mode of use of the product.
297
+
298
+ "Installation Information" for a User Product means any methods,
299
+ procedures, authorization keys, or other information required to install
300
+ and execute modified versions of a covered work in that User Product from
301
+ a modified version of its Corresponding Source. The information must
302
+ suffice to ensure that the continued functioning of the modified object
303
+ code is in no case prevented or interfered with solely because
304
+ modification has been made.
305
+
306
+ If you convey an object code work under this section in, or with, or
307
+ specifically for use in, a User Product, and the conveying occurs as
308
+ part of a transaction in which the right of possession and use of the
309
+ User Product is transferred to the recipient in perpetuity or for a
310
+ fixed term (regardless of how the transaction is characterized), the
311
+ Corresponding Source conveyed under this section must be accompanied
312
+ by the Installation Information. But this requirement does not apply
313
+ if neither you nor any third party retains the ability to install
314
+ modified object code on the User Product (for example, the work has
315
+ been installed in ROM).
316
+
317
+ The requirement to provide Installation Information does not include a
318
+ requirement to continue to provide support service, warranty, or updates
319
+ for a work that has been modified or installed by the recipient, or for
320
+ the User Product in which it has been modified or installed. Access to a
321
+ network may be denied when the modification itself materially and
322
+ adversely affects the operation of the network or violates the rules and
323
+ protocols for communication across the network.
324
+
325
+ Corresponding Source conveyed, and Installation Information provided,
326
+ in accord with this section must be in a format that is publicly
327
+ documented (and with an implementation available to the public in
328
+ source code form), and must require no special password or key for
329
+ unpacking, reading or copying.
330
+
331
+ 7. Additional Terms.
332
+
333
+ "Additional permissions" are terms that supplement the terms of this
334
+ License by making exceptions from one or more of its conditions.
335
+ Additional permissions that are applicable to the entire Program shall
336
+ be treated as though they were included in this License, to the extent
337
+ that they are valid under applicable law. If additional permissions
338
+ apply only to part of the Program, that part may be used separately
339
+ under those permissions, but the entire Program remains governed by
340
+ this License without regard to the additional permissions.
341
+
342
+ When you convey a copy of a covered work, you may at your option
343
+ remove any additional permissions from that copy, or from any part of
344
+ it. (Additional permissions may be written to require their own
345
+ removal in certain cases when you modify the work.) You may place
346
+ additional permissions on material, added by you to a covered work,
347
+ for which you have or can give appropriate copyright permission.
348
+
349
+ Notwithstanding any other provision of this License, for material you
350
+ add to a covered work, you may (if authorized by the copyright holders of
351
+ that material) supplement the terms of this License with terms:
352
+
353
+ a) Disclaiming warranty or limiting liability differently from the
354
+ terms of sections 15 and 16 of this License; or
355
+
356
+ b) Requiring preservation of specified reasonable legal notices or
357
+ author attributions in that material or in the Appropriate Legal
358
+ Notices displayed by works containing it; or
359
+
360
+ c) Prohibiting misrepresentation of the origin of that material, or
361
+ requiring that modified versions of such material be marked in
362
+ reasonable ways as different from the original version; or
363
+
364
+ d) Limiting the use for publicity purposes of names of licensors or
365
+ authors of the material; or
366
+
367
+ e) Declining to grant rights under trademark law for use of some
368
+ trade names, trademarks, or service marks; or
369
+
370
+ f) Requiring indemnification of licensors and authors of that
371
+ material by anyone who conveys the material (or modified versions of
372
+ it) with contractual assumptions of liability to the recipient, for
373
+ any liability that these contractual assumptions directly impose on
374
+ those licensors and authors.
375
+
376
+ All other non-permissive additional terms are considered "further
377
+ restrictions" within the meaning of section 10. If the Program as you
378
+ received it, or any part of it, contains a notice stating that it is
379
+ governed by this License along with a term that is a further
380
+ restriction, you may remove that term. If a license document contains
381
+ a further restriction but permits relicensing or conveying under this
382
+ License, you may add to a covered work material governed by the terms
383
+ of that license document, provided that the further restriction does
384
+ not survive such relicensing or conveying.
385
+
386
+ If you add terms to a covered work in accord with this section, you
387
+ must place, in the relevant source files, a statement of the
388
+ additional terms that apply to those files, or a notice indicating
389
+ where to find the applicable terms.
390
+
391
+ Additional terms, permissive or non-permissive, may be stated in the
392
+ form of a separately written license, or stated as exceptions;
393
+ the above requirements apply either way.
394
+
395
+ 8. Termination.
396
+
397
+ You may not propagate or modify a covered work except as expressly
398
+ provided under this License. Any attempt otherwise to propagate or
399
+ modify it is void, and will automatically terminate your rights under
400
+ this License (including any patent licenses granted under the third
401
+ paragraph of section 11).
402
+
403
+ However, if you cease all violation of this License, then your
404
+ license from a particular copyright holder is reinstated (a)
405
+ provisionally, unless and until the copyright holder explicitly and
406
+ finally terminates your license, and (b) permanently, if the copyright
407
+ holder fails to notify you of the violation by some reasonable means
408
+ prior to 60 days after the cessation.
409
+
410
+ Moreover, your license from a particular copyright holder is
411
+ reinstated permanently if the copyright holder notifies you of the
412
+ violation by some reasonable means, this is the first time you have
413
+ received notice of violation of this License (for any work) from that
414
+ copyright holder, and you cure the violation prior to 30 days after
415
+ your receipt of the notice.
416
+
417
+ Termination of your rights under this section does not terminate the
418
+ licenses of parties who have received copies or rights from you under
419
+ this License. If your rights have been terminated and not permanently
420
+ reinstated, you do not qualify to receive new licenses for the same
421
+ material under section 10.
422
+
423
+ 9. Acceptance Not Required for Having Copies.
424
+
425
+ You are not required to accept this License in order to receive or
426
+ run a copy of the Program. Ancillary propagation of a covered work
427
+ occurring solely as a consequence of using peer-to-peer transmission
428
+ to receive a copy likewise does not require acceptance. However,
429
+ nothing other than this License grants you permission to propagate or
430
+ modify any covered work. These actions infringe copyright if you do
431
+ not accept this License. Therefore, by modifying or propagating a
432
+ covered work, you indicate your acceptance of this License to do so.
433
+
434
+ 10. Automatic Licensing of Downstream Recipients.
435
+
436
+ Each time you convey a covered work, the recipient automatically
437
+ receives a license from the original licensors, to run, modify and
438
+ propagate that work, subject to this License. You are not responsible
439
+ for enforcing compliance by third parties with this License.
440
+
441
+ An "entity transaction" is a transaction transferring control of an
442
+ organization, or substantially all assets of one, or subdividing an
443
+ organization, or merging organizations. If propagation of a covered
444
+ work results from an entity transaction, each party to that
445
+ transaction who receives a copy of the work also receives whatever
446
+ licenses to the work the party's predecessor in interest had or could
447
+ give under the previous paragraph, plus a right to possession of the
448
+ Corresponding Source of the work from the predecessor in interest, if
449
+ the predecessor has it or can get it with reasonable efforts.
450
+
451
+ You may not impose any further restrictions on the exercise of the
452
+ rights granted or affirmed under this License. For example, you may
453
+ not impose a license fee, royalty, or other charge for exercise of
454
+ rights granted under this License, and you may not initiate litigation
455
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
456
+ any patent claim is infringed by making, using, selling, offering for
457
+ sale, or importing the Program or any portion of it.
458
+
459
+ 11. Patents.
460
+
461
+ A "contributor" is a copyright holder who authorizes use under this
462
+ License of the Program or a work on which the Program is based. The
463
+ work thus licensed is called the contributor's "contributor version".
464
+
465
+ A contributor's "essential patent claims" are all patent claims
466
+ owned or controlled by the contributor, whether already acquired or
467
+ hereafter acquired, that would be infringed by some manner, permitted
468
+ by this License, of making, using, or selling its contributor version,
469
+ but do not include claims that would be infringed only as a
470
+ consequence of further modification of the contributor version. For
471
+ purposes of this definition, "control" includes the right to grant
472
+ patent sublicenses in a manner consistent with the requirements of
473
+ this License.
474
+
475
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
476
+ patent license under the contributor's essential patent claims, to
477
+ make, use, sell, offer for sale, import and otherwise run, modify and
478
+ propagate the contents of its contributor version.
479
+
480
+ In the following three paragraphs, a "patent license" is any express
481
+ agreement or commitment, however denominated, not to enforce a patent
482
+ (such as an express permission to practice a patent or covenant not to
483
+ sue for patent infringement). To "grant" such a patent license to a
484
+ party means to make such an agreement or commitment not to enforce a
485
+ patent against the party.
486
+
487
+ If you convey a covered work, knowingly relying on a patent license,
488
+ and the Corresponding Source of the work is not available for anyone
489
+ to copy, free of charge and under the terms of this License, through a
490
+ publicly available network server or other readily accessible means,
491
+ then you must either (1) cause the Corresponding Source to be so
492
+ available, or (2) arrange to deprive yourself of the benefit of the
493
+ patent license for this particular work, or (3) arrange, in a manner
494
+ consistent with the requirements of this License, to extend the patent
495
+ license to downstream recipients. "Knowingly relying" means you have
496
+ actual knowledge that, but for the patent license, your conveying the
497
+ covered work in a country, or your recipient's use of the covered work
498
+ in a country, would infringe one or more identifiable patents in that
499
+ country that you have reason to believe are valid.
500
+
501
+ If, pursuant to or in connection with a single transaction or
502
+ arrangement, you convey, or propagate by procuring conveyance of, a
503
+ covered work, and grant a patent license to some of the parties
504
+ receiving the covered work authorizing them to use, propagate, modify
505
+ or convey a specific copy of the covered work, then the patent license
506
+ you grant is automatically extended to all recipients of the covered
507
+ work and works based on it.
508
+
509
+ A patent license is "discriminatory" if it does not include within
510
+ the scope of its coverage, prohibits the exercise of, or is
511
+ conditioned on the non-exercise of one or more of the rights that are
512
+ specifically granted under this License. You may not convey a covered
513
+ work if you are a party to an arrangement with a third party that is
514
+ in the business of distributing software, under which you make payment
515
+ to the third party based on the extent of your activity of conveying
516
+ the work, and under which the third party grants, to any of the
517
+ parties who would receive the covered work from you, a discriminatory
518
+ patent license (a) in connection with copies of the covered work
519
+ conveyed by you (or copies made from those copies), or (b) primarily
520
+ for and in connection with specific products or compilations that
521
+ contain the covered work, unless you entered into that arrangement,
522
+ or that patent license was granted, prior to 28 March 2007.
523
+
524
+ Nothing in this License shall be construed as excluding or limiting
525
+ any implied license or other defenses to infringement that may
526
+ otherwise be available to you under applicable patent law.
527
+
528
+ 12. No Surrender of Others' Freedom.
529
+
530
+ If conditions are imposed on you (whether by court order, agreement or
531
+ otherwise) that contradict the conditions of this License, they do not
532
+ excuse you from the conditions of this License. If you cannot convey a
533
+ covered work so as to satisfy simultaneously your obligations under this
534
+ License and any other pertinent obligations, then as a consequence you may
535
+ not convey it at all. For example, if you agree to terms that obligate you
536
+ to collect a royalty for further conveying from those to whom you convey
537
+ the Program, the only way you could satisfy both those terms and this
538
+ License would be to refrain entirely from conveying the Program.
539
+
540
+ 13. Remote Network Interaction; Use with the GNU General Public License.
541
+
542
+ Notwithstanding any other provision of this License, if you modify the
543
+ Program, your modified version must prominently offer all users
544
+ interacting with it remotely through a computer network (if your version
545
+ supports such interaction) an opportunity to receive the Corresponding
546
+ Source of your version by providing access to the Corresponding Source
547
+ from a network server at no charge, through some standard or customary
548
+ means of facilitating copying of software. This Corresponding Source
549
+ shall include the Corresponding Source for any work covered by version 3
550
+ of the GNU General Public License that is incorporated pursuant to the
551
+ following paragraph.
552
+
553
+ Notwithstanding any other provision of this License, you have
554
+ permission to link or combine any covered work with a work licensed
555
+ under version 3 of the GNU General Public License into a single
556
+ combined work, and to convey the resulting work. The terms of this
557
+ License will continue to apply to the part which is the covered work,
558
+ but the work with which it is combined will remain governed by version
559
+ 3 of the GNU General Public License.
560
+
561
+ 14. Revised Versions of this License.
562
+
563
+ The Free Software Foundation may publish revised and/or new versions of
564
+ the GNU Affero General Public License from time to time. Such new versions
565
+ will be similar in spirit to the present version, but may differ in detail to
566
+ address new problems or concerns.
567
+
568
+ Each version is given a distinguishing version number. If the
569
+ Program specifies that a certain numbered version of the GNU Affero General
570
+ Public License "or any later version" applies to it, you have the
571
+ option of following the terms and conditions either of that numbered
572
+ version or of any later version published by the Free Software
573
+ Foundation. If the Program does not specify a version number of the
574
+ GNU Affero General Public License, you may choose any version ever published
575
+ by the Free Software Foundation.
576
+
577
+ If the Program specifies that a proxy can decide which future
578
+ versions of the GNU Affero General Public License can be used, that proxy's
579
+ public statement of acceptance of a version permanently authorizes you
580
+ to choose that version for the Program.
581
+
582
+ Later license versions may give you additional or different
583
+ permissions. However, no additional obligations are imposed on any
584
+ author or copyright holder as a result of your choosing to follow a
585
+ later version.
586
+
587
+ 15. Disclaimer of Warranty.
588
+
589
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
590
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
591
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
592
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
593
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
594
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
595
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
596
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
597
+
598
+ 16. Limitation of Liability.
599
+
600
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
601
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
602
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
603
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
604
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
605
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
606
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
607
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
608
+ SUCH DAMAGES.
609
+
610
+ 17. Interpretation of Sections 15 and 16.
611
+
612
+ If the disclaimer of warranty and limitation of liability provided
613
+ above cannot be given local legal effect according to their terms,
614
+ reviewing courts shall apply local law that most closely approximates
615
+ an absolute waiver of all civil liability in connection with the
616
+ Program, unless a warranty or assumption of liability accompanies a
617
+ copy of the Program in return for a fee.
618
+
619
+ END OF TERMS AND CONDITIONS
620
+
621
+ How to Apply These Terms to Your New Programs
622
+
623
+ If you develop a new program, and you want it to be of the greatest
624
+ possible use to the public, the best way to achieve this is to make it
625
+ free software which everyone can redistribute and change under these terms.
626
+
627
+ To do so, attach the following notices to the program. It is safest
628
+ to attach them to the start of each source file to most effectively
629
+ state the exclusion of warranty; and each file should have at least
630
+ the "copyright" line and a pointer to where the full notice is found.
631
+
632
+ <one line to give the program's name and a brief idea of what it does.>
633
+ Copyright (C) <year> <name of author>
634
+
635
+ This program is free software: you can redistribute it and/or modify
636
+ it under the terms of the GNU Affero General Public License as published
637
+ by the Free Software Foundation, either version 3 of the License, or
638
+ (at your option) any later version.
639
+
640
+ This program is distributed in the hope that it will be useful,
641
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
642
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
643
+ GNU Affero General Public License for more details.
644
+
645
+ You should have received a copy of the GNU Affero General Public License
646
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
647
+
648
+ Also add information on how to contact you by electronic and paper mail.
649
+
650
+ If your software can interact with users remotely through a computer
651
+ network, you should also make sure that it provides a way for users to
652
+ get its source. For example, if your program is a web application, its
653
+ interface could display a "Source" link that leads users to an archive
654
+ of the code. There are many ways you could offer source, and different
655
+ solutions will be better for different programs; see section 13 for the
656
+ specific requirements.
657
+
658
+ You should also get your employer (if you work as a programmer) or school,
659
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
660
+ For more information on this, and how to apply and follow the GNU AGPL, see
661
+ <https://www.gnu.org/licenses/>.
makeavid_sd/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ # makeavid-sd-tpu
makeavid_sd/makeavid_sd/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ __version__ = '0.1.0'
makeavid_sd/makeavid_sd/flax_impl/__init__.py ADDED
File without changes
makeavid_sd/makeavid_sd/flax_impl/dataset.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import List, Dict, Any, Union, Optional
3
+
4
+ import torch
5
+ from torch.utils.data import DataLoader, ConcatDataset
6
+ import datasets
7
+ from diffusers import DDPMScheduler
8
+ from functools import partial
9
+ import random
10
+
11
+ import numpy as np
12
+
13
+
14
+ @torch.no_grad()
15
+ def collate_fn(
16
+ batch: List[Dict[str, Any]],
17
+ noise_scheduler: DDPMScheduler,
18
+ num_frames: int,
19
+ hint_spacing: Optional[int] = None,
20
+ as_numpy: bool = True
21
+ ) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
22
+ if hint_spacing is None or hint_spacing < 1:
23
+ hint_spacing = num_frames
24
+ if as_numpy:
25
+ dtype = np.float32
26
+ else:
27
+ dtype = torch.float32
28
+ prompts = []
29
+ videos = []
30
+ for s in batch:
31
+ # prompt
32
+ prompts.append(torch.tensor(s['prompt']).to(dtype = torch.float32))
33
+ # frames
34
+ frames = torch.tensor(s['video']).to(dtype = torch.float32)
35
+ max_frames = len(frames)
36
+ assert max_frames >= num_frames
37
+ video_slice = random.randint(0, max_frames - num_frames)
38
+ frames = frames[video_slice:video_slice + num_frames]
39
+ frames = frames.permute(1, 0, 2, 3) # f, c, h, w -> c, f, h, w
40
+ videos.append(frames)
41
+
42
+ encoder_hidden_states = torch.cat(prompts) # b, 77, 768
43
+
44
+ latents = torch.stack(videos) # b, c, f, h, w
45
+ latents = latents * 0.18215
46
+ hint_latents = latents[:, :, ::hint_spacing, :, :]
47
+ hint_latents = hint_latents.repeat_interleave(hint_spacing, 2)
48
+ #hint_latents = hint_latents[:, :, :num_frames-1, :, :]
49
+ #input_latents = latents[:, :, 1:, :, :]
50
+ input_latents = latents
51
+ noise = torch.randn_like(input_latents)
52
+ bsz = input_latents.shape[0]
53
+ timesteps = torch.randint(
54
+ 0,
55
+ noise_scheduler.config.num_train_timesteps,
56
+ (bsz,),
57
+ dtype = torch.int64
58
+ )
59
+ noisy_latents = noise_scheduler.add_noise(input_latents, noise, timesteps)
60
+ mask = torch.zeros([
61
+ noisy_latents.shape[0],
62
+ 1,
63
+ noisy_latents.shape[2],
64
+ noisy_latents.shape[3],
65
+ noisy_latents.shape[4]
66
+ ])
67
+ latent_model_input = torch.cat([noisy_latents, mask, hint_latents], dim = 1)
68
+
69
+ latent_model_input = latent_model_input.to(memory_format = torch.contiguous_format)
70
+ encoder_hidden_states = encoder_hidden_states.to(memory_format = torch.contiguous_format)
71
+ timesteps = timesteps.to(memory_format = torch.contiguous_format)
72
+ noise = noise.to(memory_format = torch.contiguous_format)
73
+
74
+ if as_numpy:
75
+ latent_model_input = latent_model_input.numpy().astype(dtype)
76
+ encoder_hidden_states = encoder_hidden_states.numpy().astype(dtype)
77
+ timesteps = timesteps.numpy().astype(np.int32)
78
+ noise = noise.numpy().astype(dtype)
79
+ else:
80
+ latent_model_input = latent_model_input.to(dtype = dtype)
81
+ encoder_hidden_states = encoder_hidden_states.to(dtype = dtype)
82
+ noise = noise.to(dtype = dtype)
83
+
84
+ return {
85
+ 'latent_model_input': latent_model_input,
86
+ 'encoder_hidden_states': encoder_hidden_states,
87
+ 'timesteps': timesteps,
88
+ 'noise': noise
89
+ }
90
+
91
+ def worker_init_fn(worker_id: int):
92
+ wseed = torch.initial_seed() % 4294967294 # max val for random 2**32 - 1
93
+ random.seed(wseed)
94
+ np.random.seed(wseed)
95
+
96
+
97
+ def load_dataset(
98
+ dataset_path: str,
99
+ model_path: str,
100
+ cache_dir: Optional[str] = None,
101
+ batch_size: int = 1,
102
+ num_frames: int = 24,
103
+ hint_spacing: Optional[int] = None,
104
+ num_workers: int = 0,
105
+ shuffle: bool = False,
106
+ as_numpy: bool = True,
107
+ pin_memory: bool = False,
108
+ pin_memory_device: str = ''
109
+ ) -> DataLoader:
110
+ noise_scheduler: DDPMScheduler = DDPMScheduler.from_pretrained(
111
+ model_path,
112
+ subfolder = 'scheduler'
113
+ )
114
+ dataset = datasets.load_dataset(
115
+ dataset_path,
116
+ streaming = False,
117
+ cache_dir = cache_dir
118
+ )
119
+ merged_dataset = ConcatDataset([ dataset[s] for s in dataset ])
120
+ dataloader = DataLoader(
121
+ merged_dataset,
122
+ batch_size = batch_size,
123
+ num_workers = num_workers,
124
+ persistent_workers = num_workers > 0,
125
+ drop_last = True,
126
+ shuffle = shuffle,
127
+ worker_init_fn = worker_init_fn,
128
+ collate_fn = partial(collate_fn,
129
+ noise_scheduler = noise_scheduler,
130
+ num_frames = num_frames,
131
+ hint_spacing = hint_spacing,
132
+ as_numpy = as_numpy
133
+ ),
134
+ pin_memory = pin_memory,
135
+ pin_memory_device = pin_memory_device
136
+ )
137
+ return dataloader
138
+
139
+
140
+ def validate_dataset(
141
+ dataset_path: str
142
+ ) -> List[int]:
143
+ import os
144
+ import json
145
+ data_path = os.path.join(dataset_path, 'data')
146
+ meta = set(os.path.splitext(x)[0] for x in os.listdir(os.path.join(data_path, 'metadata')))
147
+ prompts = set(os.path.splitext(x)[0] for x in os.listdir(os.path.join(data_path, 'prompts')))
148
+ videos = set(os.path.splitext(x)[0] for x in os.listdir(os.path.join(data_path, 'videos')))
149
+ ok = meta.intersection(prompts).intersection(videos)
150
+ all_of_em = meta.union(prompts).union(videos)
151
+ not_ok = []
152
+ for a in all_of_em:
153
+ if a not in ok:
154
+ not_ok.append(a)
155
+ ok = list(ok)
156
+ ok.sort()
157
+ with open(os.path.join(data_path, 'id_list.json'), 'w') as f:
158
+ json.dump(ok, f)
159
+
makeavid_sd/makeavid_sd/flax_impl/flax_attention_pseudo3d.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Optional
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import flax.linen as nn
7
+
8
+ import einops
9
+
10
+ #from flax_memory_efficient_attention import jax_memory_efficient_attention
11
+ #from flax_attention import FlaxAttention
12
+ from diffusers.models.attention_flax import FlaxAttention
13
+
14
+
15
+ class TransformerPseudo3DModel(nn.Module):
16
+ in_channels: int
17
+ num_attention_heads: int
18
+ attention_head_dim: int
19
+ num_layers: int = 1
20
+ use_memory_efficient_attention: bool = False
21
+ dtype: jnp.dtype = jnp.float32
22
+
23
+ def setup(self) -> None:
24
+ inner_dim = self.num_attention_heads * self.attention_head_dim
25
+ self.norm = nn.GroupNorm(
26
+ num_groups = 32,
27
+ epsilon = 1e-5
28
+ )
29
+ self.proj_in = nn.Conv(
30
+ inner_dim,
31
+ kernel_size = (1, 1),
32
+ strides = (1, 1),
33
+ padding = 'VALID',
34
+ dtype = self.dtype
35
+ )
36
+ transformer_blocks = []
37
+ #CheckpointTransformerBlock = nn.checkpoint(
38
+ # BasicTransformerBlockPseudo3D,
39
+ # static_argnums = (2,3,4)
40
+ # #prevent_cse = False
41
+ #)
42
+ CheckpointTransformerBlock = BasicTransformerBlockPseudo3D
43
+ for _ in range(self.num_layers):
44
+ transformer_blocks.append(CheckpointTransformerBlock(
45
+ dim = inner_dim,
46
+ num_attention_heads = self.num_attention_heads,
47
+ attention_head_dim = self.attention_head_dim,
48
+ use_memory_efficient_attention = self.use_memory_efficient_attention,
49
+ dtype = self.dtype
50
+ ))
51
+ self.transformer_blocks = transformer_blocks
52
+ self.proj_out = nn.Conv(
53
+ inner_dim,
54
+ kernel_size = (1, 1),
55
+ strides = (1, 1),
56
+ padding = 'VALID',
57
+ dtype = self.dtype
58
+ )
59
+
60
+ def __call__(self,
61
+ hidden_states: jax.Array,
62
+ encoder_hidden_states: Optional[jax.Array] = None
63
+ ) -> jax.Array:
64
+ is_video = hidden_states.ndim == 5
65
+ f: Optional[int] = None
66
+ if is_video:
67
+ # jax is channels last
68
+ # b,c,f,h,w WRONG
69
+ # b,f,h,w,c CORRECT
70
+ # b, c, f, h, w = hidden_states.shape
71
+ #hidden_states = einops.rearrange(hidden_states, 'b c f h w -> (b f) c h w')
72
+ b, f, h, w, c = hidden_states.shape
73
+ hidden_states = einops.rearrange(hidden_states, 'b f h w c -> (b f) h w c')
74
+
75
+ batch, height, width, channels = hidden_states.shape
76
+ residual = hidden_states
77
+ hidden_states = self.norm(hidden_states)
78
+ hidden_states = self.proj_in(hidden_states)
79
+ hidden_states = hidden_states.reshape(batch, height * width, channels)
80
+ for block in self.transformer_blocks:
81
+ hidden_states = block(
82
+ hidden_states,
83
+ encoder_hidden_states,
84
+ f,
85
+ height,
86
+ width
87
+ )
88
+ hidden_states = hidden_states.reshape(batch, height, width, channels)
89
+ hidden_states = self.proj_out(hidden_states)
90
+ hidden_states = hidden_states + residual
91
+ if is_video:
92
+ hidden_states = einops.rearrange(hidden_states, '(b f) h w c -> b f h w c', b = b)
93
+ return hidden_states
94
+
95
+
96
+ class BasicTransformerBlockPseudo3D(nn.Module):
97
+ dim: int
98
+ num_attention_heads: int
99
+ attention_head_dim: int
100
+ use_memory_efficient_attention: bool = False
101
+ dtype: jnp.dtype = jnp.float32
102
+
103
+ def setup(self) -> None:
104
+ self.attn1 = FlaxAttention(
105
+ query_dim = self.dim,
106
+ heads = self.num_attention_heads,
107
+ dim_head = self.attention_head_dim,
108
+ use_memory_efficient_attention = self.use_memory_efficient_attention,
109
+ dtype = self.dtype
110
+ )
111
+ self.ff = FeedForward(dim = self.dim, dtype = self.dtype)
112
+ self.attn2 = FlaxAttention(
113
+ query_dim = self.dim,
114
+ heads = self.num_attention_heads,
115
+ dim_head = self.attention_head_dim,
116
+ use_memory_efficient_attention = self.use_memory_efficient_attention,
117
+ dtype = self.dtype
118
+ )
119
+ self.attn_temporal = FlaxAttention(
120
+ query_dim = self.dim,
121
+ heads = self.num_attention_heads,
122
+ dim_head = self.attention_head_dim,
123
+ use_memory_efficient_attention = self.use_memory_efficient_attention,
124
+ dtype = self.dtype
125
+ )
126
+ self.norm1 = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
127
+ self.norm2 = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
128
+ self.norm_temporal = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
129
+ self.norm3 = nn.LayerNorm(epsilon = 1e-5, dtype = self.dtype)
130
+
131
+ def __call__(self,
132
+ hidden_states: jax.Array,
133
+ context: Optional[jax.Array] = None,
134
+ frames_length: Optional[int] = None,
135
+ height: Optional[int] = None,
136
+ width: Optional[int] = None
137
+ ) -> jax.Array:
138
+ if context is not None and frames_length is not None:
139
+ context = context.repeat(frames_length, axis = 0)
140
+ # self attention
141
+ norm_hidden_states = self.norm1(hidden_states)
142
+ hidden_states = self.attn1(norm_hidden_states) + hidden_states
143
+ # cross attention
144
+ norm_hidden_states = self.norm2(hidden_states)
145
+ hidden_states = self.attn2(
146
+ norm_hidden_states,
147
+ context = context
148
+ ) + hidden_states
149
+ # temporal attention
150
+ if frames_length is not None:
151
+ #bf, hw, c = hidden_states.shape
152
+ # (b f) (h w) c -> b f (h w) c
153
+ #hidden_states = hidden_states.reshape(bf // frames_length, frames_length, hw, c)
154
+ #b, f, hw, c = hidden_states.shape
155
+ # b f (h w) c -> b (h w) f c
156
+ #hidden_states = hidden_states.transpose(0, 2, 1, 3)
157
+ # b (h w) f c -> (b h w) f c
158
+ #hidden_states = hidden_states.reshape(b * hw, frames_length, c)
159
+ hidden_states = einops.rearrange(
160
+ hidden_states,
161
+ '(b f) (h w) c -> (b h w) f c',
162
+ f = frames_length,
163
+ h = height,
164
+ w = width
165
+ )
166
+ norm_hidden_states = self.norm_temporal(hidden_states)
167
+ hidden_states = self.attn_temporal(norm_hidden_states) + hidden_states
168
+ # (b h w) f c -> b (h w) f c
169
+ #hidden_states = hidden_states.reshape(b, hw, f, c)
170
+ # b (h w) f c -> b f (h w) c
171
+ #hidden_states = hidden_states.transpose(0, 2, 1, 3)
172
+ # b f h w c -> (b f) (h w) c
173
+ #hidden_states = hidden_states.reshape(bf, hw, c)
174
+ hidden_states = einops.rearrange(
175
+ hidden_states,
176
+ '(b h w) f c -> (b f) (h w) c',
177
+ f = frames_length,
178
+ h = height,
179
+ w = width
180
+ )
181
+ norm_hidden_states = self.norm3(hidden_states)
182
+ hidden_states = self.ff(norm_hidden_states) + hidden_states
183
+ return hidden_states
184
+
185
+
186
+ class FeedForward(nn.Module):
187
+ dim: int
188
+ dtype: jnp.dtype = jnp.float32
189
+
190
+ def setup(self) -> None:
191
+ self.net_0 = GEGLU(self.dim, self.dtype)
192
+ self.net_2 = nn.Dense(self.dim, dtype = self.dtype)
193
+
194
+ def __call__(self, hidden_states: jax.Array) -> jax.Array:
195
+ hidden_states = self.net_0(hidden_states)
196
+ hidden_states = self.net_2(hidden_states)
197
+ return hidden_states
198
+
199
+
200
+ class GEGLU(nn.Module):
201
+ dim: int
202
+ dtype: jnp.dtype = jnp.float32
203
+
204
+ def setup(self) -> None:
205
+ inner_dim = self.dim * 4
206
+ self.proj = nn.Dense(inner_dim * 2, dtype = self.dtype)
207
+
208
+ def __call__(self, hidden_states: jax.Array) -> jax.Array:
209
+ hidden_states = self.proj(hidden_states)
210
+ hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis = 2)
211
+ return hidden_linear * nn.gelu(hidden_gelu)
212
+
makeavid_sd/makeavid_sd/flax_impl/flax_embeddings.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import jax
3
+ import jax.numpy as jnp
4
+ import flax.linen as nn
5
+
6
+
7
+ def get_sinusoidal_embeddings(
8
+ timesteps: jax.Array,
9
+ embedding_dim: int,
10
+ freq_shift: float = 1,
11
+ min_timescale: float = 1,
12
+ max_timescale: float = 1.0e4,
13
+ flip_sin_to_cos: bool = False,
14
+ scale: float = 1.0,
15
+ dtype: jnp.dtype = jnp.float32
16
+ ) -> jax.Array:
17
+ assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
18
+ assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
19
+ num_timescales = float(embedding_dim // 2)
20
+ log_timescale_increment = jnp.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
21
+ inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype = dtype) * -log_timescale_increment)
22
+ emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
23
+
24
+ # scale embeddings
25
+ scaled_time = scale * emb
26
+
27
+ if flip_sin_to_cos:
28
+ signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis = 1)
29
+ else:
30
+ signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis = 1)
31
+ signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
32
+ return signal
33
+
34
+
35
+ class TimestepEmbedding(nn.Module):
36
+ time_embed_dim: int = 32
37
+ dtype: jnp.dtype = jnp.float32
38
+
39
+ @nn.compact
40
+ def __call__(self, temb: jax.Array) -> jax.Array:
41
+ temb = nn.Dense(self.time_embed_dim, dtype = self.dtype, name = "linear_1")(temb)
42
+ temb = nn.silu(temb)
43
+ temb = nn.Dense(self.time_embed_dim, dtype = self.dtype, name = "linear_2")(temb)
44
+ return temb
45
+
46
+
47
+ class Timesteps(nn.Module):
48
+ dim: int = 32
49
+ flip_sin_to_cos: bool = False
50
+ freq_shift: float = 1
51
+ dtype: jnp.dtype = jnp.float32
52
+
53
+ @nn.compact
54
+ def __call__(self, timesteps: jax.Array) -> jax.Array:
55
+ return get_sinusoidal_embeddings(
56
+ timesteps = timesteps,
57
+ embedding_dim = self.dim,
58
+ flip_sin_to_cos = self.flip_sin_to_cos,
59
+ freq_shift = self.freq_shift,
60
+ dtype = self.dtype
61
+ )
62
+
makeavid_sd/makeavid_sd/flax_impl/flax_resnet_pseudo3d.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Optional, Union, Sequence
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import flax.linen as nn
7
+
8
+ import einops
9
+
10
+
11
+ class ConvPseudo3D(nn.Module):
12
+ features: int
13
+ kernel_size: Sequence[int]
14
+ strides: Union[None, int, Sequence[int]] = 1
15
+ padding: nn.linear.PaddingLike = 'SAME'
16
+ dtype: jnp.dtype = jnp.float32
17
+
18
+ def setup(self) -> None:
19
+ self.spatial_conv = nn.Conv(
20
+ features = self.features,
21
+ kernel_size = self.kernel_size,
22
+ strides = self.strides,
23
+ padding = self.padding,
24
+ dtype = self.dtype
25
+ )
26
+ self.temporal_conv = nn.Conv(
27
+ features = self.features,
28
+ kernel_size = (3,),
29
+ padding = 'SAME',
30
+ dtype = self.dtype,
31
+ bias_init = nn.initializers.zeros_init()
32
+ # TODO dirac delta (identity) initialization impl
33
+ # kernel_init = torch.nn.init.dirac_ <-> jax/lax
34
+ )
35
+
36
+ def __call__(self, x: jax.Array, convolve_across_time: bool = True) -> jax.Array:
37
+ is_video = x.ndim == 5
38
+ convolve_across_time = convolve_across_time and is_video
39
+ if is_video:
40
+ b, f, h, w, c = x.shape
41
+ x = einops.rearrange(x, 'b f h w c -> (b f) h w c')
42
+ x = self.spatial_conv(x)
43
+ if is_video:
44
+ x = einops.rearrange(x, '(b f) h w c -> b f h w c', b = b)
45
+ b, f, h, w, c = x.shape
46
+ if not convolve_across_time:
47
+ return x
48
+ if is_video:
49
+ x = einops.rearrange(x, 'b f h w c -> (b h w) f c')
50
+ x = self.temporal_conv(x)
51
+ x = einops.rearrange(x, '(b h w) f c -> b f h w c', h = h, w = w)
52
+ return x
53
+
54
+
55
+ class UpsamplePseudo3D(nn.Module):
56
+ out_channels: int
57
+ dtype: jnp.dtype = jnp.float32
58
+
59
+ def setup(self) -> None:
60
+ self.conv = ConvPseudo3D(
61
+ features = self.out_channels,
62
+ kernel_size = (3, 3),
63
+ strides = (1, 1),
64
+ padding = ((1, 1), (1, 1)),
65
+ dtype = self.dtype
66
+ )
67
+
68
+ def __call__(self, hidden_states: jax.Array) -> jax.Array:
69
+ is_video = hidden_states.ndim == 5
70
+ if is_video:
71
+ b, *_ = hidden_states.shape
72
+ hidden_states = einops.rearrange(hidden_states, 'b f h w c -> (b f) h w c')
73
+ batch, h, w, c = hidden_states.shape
74
+ hidden_states = jax.image.resize(
75
+ image = hidden_states,
76
+ shape = (batch, h * 2, w * 2, c),
77
+ method = 'nearest'
78
+ )
79
+ if is_video:
80
+ hidden_states = einops.rearrange(hidden_states, '(b f) h w c -> b f h w c', b = b)
81
+ hidden_states = self.conv(hidden_states)
82
+ return hidden_states
83
+
84
+
85
+ class DownsamplePseudo3D(nn.Module):
86
+ out_channels: int
87
+ dtype: jnp.dtype = jnp.float32
88
+
89
+ def setup(self) -> None:
90
+ self.conv = ConvPseudo3D(
91
+ features = self.out_channels,
92
+ kernel_size = (3, 3),
93
+ strides = (2, 2),
94
+ padding = ((1, 1), (1, 1)),
95
+ dtype = self.dtype
96
+ )
97
+
98
+ def __call__(self, hidden_states: jax.Array) -> jax.Array:
99
+ hidden_states = self.conv(hidden_states)
100
+ return hidden_states
101
+
102
+
103
+ class ResnetBlockPseudo3D(nn.Module):
104
+ in_channels: int
105
+ out_channels: Optional[int] = None
106
+ use_nin_shortcut: Optional[bool] = None
107
+ dtype: jnp.dtype = jnp.float32
108
+
109
+ def setup(self) -> None:
110
+ out_channels = self.in_channels if self.out_channels is None else self.out_channels
111
+ self.norm1 = nn.GroupNorm(
112
+ num_groups = 32,
113
+ epsilon = 1e-5
114
+ )
115
+ self.conv1 = ConvPseudo3D(
116
+ features = out_channels,
117
+ kernel_size = (3, 3),
118
+ strides = (1, 1),
119
+ padding = ((1, 1), (1, 1)),
120
+ dtype = self.dtype
121
+ )
122
+ self.time_emb_proj = nn.Dense(
123
+ out_channels,
124
+ dtype = self.dtype
125
+ )
126
+ self.norm2 = nn.GroupNorm(
127
+ num_groups = 32,
128
+ epsilon = 1e-5
129
+ )
130
+ self.conv2 = ConvPseudo3D(
131
+ features = out_channels,
132
+ kernel_size = (3, 3),
133
+ strides = (1, 1),
134
+ padding = ((1, 1), (1, 1)),
135
+ dtype = self.dtype
136
+ )
137
+ use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut
138
+ self.conv_shortcut = None
139
+ if use_nin_shortcut:
140
+ self.conv_shortcut = ConvPseudo3D(
141
+ features = self.out_channels,
142
+ kernel_size = (1, 1),
143
+ strides = (1, 1),
144
+ padding = 'VALID',
145
+ dtype = self.dtype
146
+ )
147
+
148
+ def __call__(self,
149
+ hidden_states: jax.Array,
150
+ temb: jax.Array
151
+ ) -> jax.Array:
152
+ is_video = hidden_states.ndim == 5
153
+ residual = hidden_states
154
+ hidden_states = self.norm1(hidden_states)
155
+ hidden_states = nn.silu(hidden_states)
156
+ hidden_states = self.conv1(hidden_states)
157
+ temb = nn.silu(temb)
158
+ temb = self.time_emb_proj(temb)
159
+ temb = jnp.expand_dims(temb, 1)
160
+ temb = jnp.expand_dims(temb, 1)
161
+ if is_video:
162
+ b, f, *_ = hidden_states.shape
163
+ hidden_states = einops.rearrange(hidden_states, 'b f h w c -> (b f) h w c')
164
+ hidden_states = hidden_states + temb.repeat(f, 0)
165
+ hidden_states = einops.rearrange(hidden_states, '(b f) h w c -> b f h w c', b = b)
166
+ else:
167
+ hidden_states = hidden_states + temb
168
+ hidden_states = self.norm2(hidden_states)
169
+ hidden_states = nn.silu(hidden_states)
170
+ hidden_states = self.conv2(hidden_states)
171
+ if self.conv_shortcut is not None:
172
+ residual = self.conv_shortcut(residual)
173
+ hidden_states = hidden_states + residual
174
+ return hidden_states
175
+
makeavid_sd/makeavid_sd/flax_impl/flax_trainer.py ADDED
@@ -0,0 +1,608 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Any, Optional, Union, Tuple, Dict, List
3
+
4
+ import os
5
+ import random
6
+ import math
7
+ import time
8
+ import numpy as np
9
+ from tqdm.auto import tqdm, trange
10
+
11
+ import torch
12
+ from torch.utils.data import DataLoader
13
+
14
+ import jax
15
+ import jax.numpy as jnp
16
+ import optax
17
+ from flax import jax_utils, traverse_util
18
+ from flax.core.frozen_dict import FrozenDict
19
+ from flax.training.train_state import TrainState
20
+ from flax.training.common_utils import shard
21
+
22
+ # convert 2D -> 3D
23
+ from diffusers import FlaxUNet2DConditionModel
24
+
25
+ # inference test, run on these on cpu
26
+ from diffusers import AutoencoderKL
27
+ from diffusers.schedulers.scheduling_ddim_flax import FlaxDDIMScheduler, DDIMSchedulerState
28
+ from transformers import CLIPTextModel, CLIPTokenizer
29
+ from PIL import Image
30
+
31
+
32
+ from .flax_unet_pseudo3d_condition import UNetPseudo3DConditionModel
33
+
34
+
35
+ def seed_all(seed: int) -> jax.random.PRNGKeyArray:
36
+ random.seed(seed)
37
+ np.random.seed(seed)
38
+ torch.manual_seed(seed)
39
+ rng = jax.random.PRNGKey(seed)
40
+ return rng
41
+
42
+ def count_params(
43
+ params: Union[Dict[str, Any],
44
+ FrozenDict[str, Any]],
45
+ filter_name: Optional[str] = None
46
+ ) -> int:
47
+ p: Dict[Tuple[str], jax.Array] = traverse_util.flatten_dict(params)
48
+ cc = 0
49
+ for k in p:
50
+ if filter_name is not None:
51
+ if filter_name in ' '.join(k):
52
+ cc += len(p[k].flatten())
53
+ else:
54
+ cc += len(p[k].flatten())
55
+ return cc
56
+
57
+ def map_2d_to_pseudo3d(
58
+ params2d: Dict[str, Any],
59
+ params3d: Dict[str, Any],
60
+ verbose: bool = True
61
+ ) -> Dict[str, Any]:
62
+ params2d = traverse_util.flatten_dict(params2d)
63
+ params3d = traverse_util.flatten_dict(params3d)
64
+ new_params = dict()
65
+ for k in params3d:
66
+ if 'spatial_conv' in k:
67
+ k2d = list(k)
68
+ k2d.remove('spatial_conv')
69
+ k2d = tuple(k2d)
70
+ if verbose:
71
+ tqdm.write(f'Spatial: {k} <- {k2d}')
72
+ p = params2d[k2d]
73
+ elif k not in params2d:
74
+ if verbose:
75
+ tqdm.write(f'Missing: {k}')
76
+ p = params3d[k]
77
+ else:
78
+ p = params2d[k]
79
+ assert p.shape == params3d[k].shape, f'shape mismatch: {k}: {p.shape} != {params3d[k].shape}'
80
+ new_params[k] = p
81
+ new_params = traverse_util.unflatten_dict(new_params)
82
+ return new_params
83
+
84
+
85
+ class FlaxTrainerUNetPseudo3D:
86
+ def __init__(self,
87
+ model_path: str,
88
+ from_pt: bool = True,
89
+ convert2d: bool = False,
90
+ sample_size: Tuple[int, int] = (64, 64),
91
+ seed: int = 0,
92
+ dtype: str = 'float32',
93
+ param_dtype: str = 'float32',
94
+ only_temporal: bool = True,
95
+ use_memory_efficient_attention = False,
96
+ verbose: bool = True
97
+ ) -> None:
98
+ self.verbose = verbose
99
+ self.tracker: Optional['wandb.sdk.wandb_run.Run'] = None
100
+ self._use_wandb: bool = False
101
+ self._tracker_meta: Dict[str, Union[float, int]] = {
102
+ 't00': 0.0,
103
+ 't0': 0.0,
104
+ 'step0': 0
105
+ }
106
+
107
+ self.log('Init JAX')
108
+ self.num_devices = jax.device_count()
109
+ self.log(f'Device count: {self.num_devices}')
110
+
111
+ self.seed = seed
112
+ self.rng: jax.random.PRNGKeyArray = seed_all(self.seed)
113
+
114
+ self.sample_size = sample_size
115
+ if dtype == 'float32':
116
+ self.dtype = jnp.float32
117
+ elif dtype == 'bfloat16':
118
+ self.dtype = jnp.bfloat16
119
+ elif dtype == 'float16':
120
+ self.dtype = jnp.float16
121
+ else:
122
+ raise ValueError(f'unknown type: {dtype}')
123
+ self.dtype_str: str = dtype
124
+ if param_dtype not in ['float32', 'bfloat16', 'float16']:
125
+ raise ValueError(f'unknown parameter type: {param_dtype}')
126
+ self.param_dtype = param_dtype
127
+ self._load_models(
128
+ model_path = model_path,
129
+ convert2d = convert2d,
130
+ from_pt = from_pt,
131
+ use_memory_efficient_attention = use_memory_efficient_attention
132
+ )
133
+ self._mark_parameters(only_temporal = only_temporal)
134
+ # optionally for validation + sampling
135
+ self.tokenizer: Optional[CLIPTokenizer] = None
136
+ self.text_encoder: Optional[CLIPTextModel] = None
137
+ self.vae: Optional[AutoencoderKL] = None
138
+ self.ddim: Optional[Tuple[FlaxDDIMScheduler, DDIMSchedulerState]] = None
139
+
140
+ def log(self, message: Any) -> None:
141
+ if self.verbose and jax.process_index() == 0:
142
+ tqdm.write(str(message))
143
+
144
+ def log_metrics(self, metrics: dict, step: int, epoch: int) -> None:
145
+ if jax.process_index() > 0 or (not self.verbose and self.tracker is None):
146
+ return
147
+ now = time.monotonic()
148
+ log_data = {
149
+ 'train/step': step,
150
+ 'train/epoch': epoch,
151
+ 'train/steps_per_sec': (step - self._tracker_meta['step0']) / (now - self._tracker_meta['t0']),
152
+ **{ f'train/{k}': v for k, v in metrics.items() }
153
+ }
154
+ self._tracker_meta['t0'] = now
155
+ self._tracker_meta['step0'] = step
156
+ self.log(log_data)
157
+ if self.tracker is not None:
158
+ self.tracker.log(log_data, step = step)
159
+
160
+
161
+ def enable_wandb(self, enable: bool = True) -> None:
162
+ self._use_wandb = enable
163
+
164
+ def _setup_wandb(self, config: Dict[str, Any] = dict()) -> None:
165
+ import wandb
166
+ import wandb.sdk
167
+ self.tracker: wandb.sdk.wandb_run.Run = wandb.init(
168
+ config = config,
169
+ settings = wandb.sdk.Settings(
170
+ username = 'anon',
171
+ host = 'anon',
172
+ email = 'anon',
173
+ root_dir = 'anon',
174
+ _executable = 'anon',
175
+ _disable_stats = True,
176
+ _disable_meta = True,
177
+ disable_code = True,
178
+ disable_git = True
179
+ ) # pls don't log sensitive data like system user names. also, fuck you for even trying.
180
+ )
181
+
182
+ def _init_tracker_meta(self) -> None:
183
+ now = time.monotonic()
184
+ self._tracker_meta = {
185
+ 't00': now,
186
+ 't0': now,
187
+ 'step0': 0
188
+ }
189
+
190
+ def _load_models(self,
191
+ model_path: str,
192
+ convert2d: bool,
193
+ from_pt: bool,
194
+ use_memory_efficient_attention: bool
195
+ ) -> None:
196
+ self.log(f'Load pretrained from {model_path}')
197
+ if convert2d:
198
+ self.log(' Convert 2D model to Pseudo3D')
199
+ self.log(' Initiate Pseudo3D model')
200
+ config = UNetPseudo3DConditionModel.load_config(model_path, subfolder = 'unet')
201
+ model = UNetPseudo3DConditionModel.from_config(
202
+ config,
203
+ sample_size = self.sample_size,
204
+ dtype = self.dtype,
205
+ param_dtype = self.param_dtype,
206
+ use_memory_efficient_attention = use_memory_efficient_attention
207
+ )
208
+ params: Dict[str, Any] = model.init_weights(self.rng).unfreeze()
209
+ self.log(' Load 2D model')
210
+ model2d, params2d = FlaxUNet2DConditionModel.from_pretrained(
211
+ model_path,
212
+ subfolder = 'unet',
213
+ dtype = self.dtype,
214
+ from_pt = from_pt
215
+ )
216
+ self.log(' Map 2D -> 3D')
217
+ params = map_2d_to_pseudo3d(params2d, params, verbose = self.verbose)
218
+ del params2d
219
+ del model2d
220
+ del config
221
+ else:
222
+ model, params = UNetPseudo3DConditionModel.from_pretrained(
223
+ model_path,
224
+ subfolder = 'unet',
225
+ from_pt = from_pt,
226
+ sample_size = self.sample_size,
227
+ dtype = self.dtype,
228
+ param_dtype = self.param_dtype,
229
+ use_memory_efficient_attention = use_memory_efficient_attention
230
+ )
231
+ self.log(f'Cast parameters to {model.param_dtype}')
232
+ if model.param_dtype == 'float32':
233
+ params = model.to_fp32(params)
234
+ elif model.param_dtype == 'float16':
235
+ params = model.to_fp16(params)
236
+ elif model.param_dtype == 'bfloat16':
237
+ params = model.to_bf16(params)
238
+ self.pretrained_model = model_path
239
+ self.model: UNetPseudo3DConditionModel = model
240
+ self.params: FrozenDict[str, Any] = FrozenDict(params)
241
+
242
+ def _mark_parameters(self, only_temporal: bool) -> None:
243
+ self.log('Mark training parameters')
244
+ if only_temporal:
245
+ self.log('Only training temporal layers')
246
+ if only_temporal:
247
+ param_partitions = traverse_util.path_aware_map(
248
+ lambda path, _: 'trainable' if 'temporal' in ' '.join(path) else 'frozen', self.params
249
+ )
250
+ else:
251
+ param_partitions = traverse_util.path_aware_map(
252
+ lambda *_: 'trainable', self.params
253
+ )
254
+ self.only_temporal = only_temporal
255
+ self.param_partitions: FrozenDict[str, Any] = FrozenDict(param_partitions)
256
+ self.log(f'Total parameters: {count_params(self.params)}')
257
+ self.log(f'Temporal parameters: {count_params(self.params, "temporal")}')
258
+
259
+ def _load_inference_models(self) -> None:
260
+ assert jax.process_index() == 0, 'not main process'
261
+ if self.text_encoder is None:
262
+ self.log('Load text encoder')
263
+ self.text_encoder = CLIPTextModel.from_pretrained(
264
+ self.pretrained_model,
265
+ subfolder = 'text_encoder'
266
+ )
267
+ if self.tokenizer is None:
268
+ self.log('Load tokenizer')
269
+ self.tokenizer = CLIPTokenizer.from_pretrained(
270
+ self.pretrained_model,
271
+ subfolder = 'tokenizer'
272
+ )
273
+ if self.vae is None:
274
+ self.log('Load vae')
275
+ self.vae = AutoencoderKL.from_pretrained(
276
+ self.pretrained_model,
277
+ subfolder = 'vae'
278
+ )
279
+ if self.ddim is None:
280
+ self.log('Load ddim scheduler')
281
+ # tuple(scheduler , scheduler state)
282
+ self.ddim = FlaxDDIMScheduler.from_pretrained(
283
+ self.pretrained_model,
284
+ subfolder = 'scheduler',
285
+ from_pt = True
286
+ )
287
+
288
+ def _unload_inference_models(self) -> None:
289
+ self.text_encoder = None
290
+ self.tokenizer = None
291
+ self.vae = None
292
+ self.ddim = None
293
+
294
+ def sample(self,
295
+ params: Union[Dict[str, Any], FrozenDict[str, Any]],
296
+ prompt: str,
297
+ image_path: str,
298
+ num_frames: int,
299
+ replicate_params: bool = True,
300
+ neg_prompt: str = '',
301
+ steps: int = 50,
302
+ cfg: float = 9.0,
303
+ unload_after_usage: bool = False
304
+ ) -> List[Image.Image]:
305
+ assert jax.process_index() == 0, 'not main process'
306
+ self.log('Sample')
307
+ self._load_inference_models()
308
+ with torch.no_grad():
309
+ tokens = self.tokenizer(
310
+ [ prompt ],
311
+ truncation = True,
312
+ return_overflowing_tokens = False,
313
+ padding = 'max_length',
314
+ return_tensors = 'pt'
315
+ ).input_ids
316
+ neg_tokens = self.tokenizer(
317
+ [ neg_prompt ],
318
+ truncation = True,
319
+ return_overflowing_tokens = False,
320
+ padding = 'max_length',
321
+ return_tensors = 'pt'
322
+ ).input_ids
323
+ encoded_prompt = self.text_encoder(input_ids = tokens).last_hidden_state
324
+ encoded_neg_prompt = self.text_encoder(input_ids = neg_tokens).last_hidden_state
325
+ hint_latent = torch.tensor(np.asarray(Image.open(image_path))).permute(2,0,1).to(torch.float32).div(255).mul(2).sub(1).unsqueeze(0)
326
+ hint_latent = self.vae.encode(hint_latent).latent_dist.mean * self.vae.config.scaling_factor #0.18215 # deterministic
327
+ hint_latent = hint_latent.unsqueeze(2).repeat_interleave(num_frames, 2)
328
+ mask = torch.zeros_like(hint_latent[:,0:1,:,:,:]) # zero mask, e.g. skip masking for now
329
+ init_latent = torch.randn_like(hint_latent)
330
+ # move to devices
331
+ encoded_prompt = jnp.array(encoded_prompt.numpy())
332
+ encoded_neg_prompt = jnp.array(encoded_neg_prompt.numpy())
333
+ hint_latent = jnp.array(hint_latent.numpy())
334
+ mask = jnp.array(mask.numpy())
335
+ init_latent = init_latent.repeat(jax.device_count(), 1, 1, 1, 1)
336
+ init_latent = jnp.array(init_latent.numpy())
337
+ self.ddim = (self.ddim[0], self.ddim[0].set_timesteps(self.ddim[1], steps))
338
+ timesteps = self.ddim[1].timesteps
339
+ if replicate_params:
340
+ params = jax_utils.replicate(params)
341
+ ddim_state = jax_utils.replicate(self.ddim[1])
342
+ encoded_prompt = jax_utils.replicate(encoded_prompt)
343
+ encoded_neg_prompt = jax_utils.replicate(encoded_neg_prompt)
344
+ hint_latent = jax_utils.replicate(hint_latent)
345
+ mask = jax_utils.replicate(mask)
346
+ # sampling fun
347
+ def sample_loop(init_latent, ddim_state, t, params, encoded_prompt, encoded_neg_prompt, hint_latent, mask):
348
+ latent_model_input = jnp.concatenate([init_latent, mask, hint_latent], axis = 1)
349
+ pred = self.model.apply(
350
+ { 'params': params },
351
+ latent_model_input,
352
+ t,
353
+ encoded_prompt
354
+ ).sample
355
+ if cfg != 1.0:
356
+ neg_pred = self.model.apply(
357
+ { 'params': params },
358
+ latent_model_input,
359
+ t,
360
+ encoded_neg_prompt
361
+ ).sample
362
+ pred = neg_pred + cfg * (pred - neg_pred)
363
+ # TODO check if noise is added at the right dimension
364
+ init_latent, ddim_state = self.ddim[0].step(ddim_state, pred, t, init_latent).to_tuple()
365
+ return init_latent, ddim_state
366
+ p_sample_loop = jax.pmap(sample_loop, 'sample', donate_argnums = ())
367
+ pbar_sample = trange(len(timesteps), desc = 'Sample', dynamic_ncols = True, smoothing = 0.1, disable = not self.verbose)
368
+ init_latent = shard(init_latent)
369
+ for i in pbar_sample:
370
+ t = timesteps[i].repeat(self.num_devices)
371
+ t = shard(t)
372
+ init_latent, ddim_state = p_sample_loop(init_latent, ddim_state, t, params, encoded_prompt, encoded_neg_prompt, hint_latent, mask)
373
+ # decode
374
+ self.log('Decode')
375
+ init_latent = torch.tensor(np.array(init_latent))
376
+ init_latent = init_latent / self.vae.config.scaling_factor
377
+ # d:0 b:1 c:2 f:3 h:4 w:5 -> d b f c h w
378
+ init_latent = init_latent.permute(0, 1, 3, 2, 4, 5)
379
+ images = []
380
+ pbar_decode = trange(len(init_latent), desc = 'Decode', dynamic_ncols = True)
381
+ for sample in init_latent:
382
+ ims = self.vae.decode(sample.squeeze()).sample
383
+ ims = ims.add(1).div(2).mul(255).round().clamp(0, 255).to(torch.uint8).permute(0,2,3,1).numpy()
384
+ ims = [ Image.fromarray(x) for x in ims ]
385
+ for im in ims:
386
+ images.append(im)
387
+ pbar_decode.update(1)
388
+ if unload_after_usage:
389
+ self._unload_inference_models()
390
+ return images
391
+
392
+ def get_params_from_state(self, state: TrainState) -> FrozenDict[Any, str]:
393
+ return FrozenDict(jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)))
394
+
395
+ def train(self,
396
+ dataloader: DataLoader,
397
+ lr: float,
398
+ num_frames: int,
399
+ log_every_step: int = 10,
400
+ save_every_epoch: int = 1,
401
+ sample_every_epoch: int = 1,
402
+ output_dir: str = 'output',
403
+ warmup: float = 0,
404
+ decay: float = 0,
405
+ epochs: int = 10,
406
+ weight_decay: float = 1e-2
407
+ ) -> None:
408
+ eps = 1e-8
409
+ total_steps = len(dataloader) * epochs
410
+ warmup_steps = math.ceil(warmup * total_steps) if warmup > 0 else 0
411
+ decay_steps = math.ceil(decay * total_steps) + warmup_steps if decay > 0 else warmup_steps + 1
412
+ self.log(f'Total steps: {total_steps}')
413
+ self.log(f'Warmup steps: {warmup_steps}')
414
+ self.log(f'Decay steps: {decay_steps - warmup_steps}')
415
+ if warmup > 0 or decay > 0:
416
+ if not decay > 0:
417
+ # only warmup, keep peak lr until end
418
+ self.log('Warmup schedule')
419
+ end_lr = lr
420
+ else:
421
+ # warmup + annealing to end lr
422
+ self.log('Warmup + cosine annealing schedule')
423
+ end_lr = eps
424
+ lr_schedule = optax.warmup_cosine_decay_schedule(
425
+ init_value = 0.0,
426
+ peak_value = lr,
427
+ warmup_steps = warmup_steps,
428
+ decay_steps = decay_steps,
429
+ end_value = end_lr
430
+ )
431
+ else:
432
+ # no warmup or decay -> constant lr
433
+ self.log('constant schedule')
434
+ lr_schedule = optax.constant_schedule(value = lr)
435
+ adamw = optax.adamw(
436
+ learning_rate = lr_schedule,
437
+ b1 = 0.9,
438
+ b2 = 0.999,
439
+ eps = eps,
440
+ weight_decay = weight_decay #0.01 # 0.0001
441
+ )
442
+ optim = optax.chain(
443
+ optax.clip_by_global_norm(max_norm = 1.0),
444
+ adamw
445
+ )
446
+ partition_optimizers = {
447
+ 'trainable': optim,
448
+ 'frozen': optax.set_to_zero()
449
+ }
450
+ tx = optax.multi_transform(partition_optimizers, self.param_partitions)
451
+ state = TrainState.create(
452
+ apply_fn = self.model.__call__,
453
+ params = self.params,
454
+ tx = tx
455
+ )
456
+ validation_rng, train_rngs = jax.random.split(self.rng)
457
+ train_rngs = jax.random.split(train_rngs, jax.local_device_count())
458
+
459
+ def train_step(state: TrainState, batch: Dict[str, jax.Array], train_rng: jax.random.PRNGKeyArray):
460
+ def compute_loss(
461
+ params: Dict[str, Any],
462
+ batch: Dict[str, jax.Array],
463
+ sample_rng: jax.random.PRNGKeyArray # unused, dataloader provides everything
464
+ ) -> jax.Array:
465
+ # 'latent_model_input': latent_model_input
466
+ # 'encoder_hidden_states': encoder_hidden_states
467
+ # 'timesteps': timesteps
468
+ # 'noise': noise
469
+ latent_model_input = batch['latent_model_input']
470
+ encoder_hidden_states = batch['encoder_hidden_states']
471
+ timesteps = batch['timesteps']
472
+ noise = batch['noise']
473
+ model_pred = self.model.apply(
474
+ { 'params': params },
475
+ latent_model_input,
476
+ timesteps,
477
+ encoder_hidden_states
478
+ ).sample
479
+ loss = (noise - model_pred) ** 2
480
+ loss = loss.mean()
481
+ return loss
482
+ grad_fn = jax.value_and_grad(compute_loss)
483
+
484
+ def loss_and_grad(
485
+ train_rng: jax.random.PRNGKeyArray
486
+ ) -> Tuple[jax.Array, Any, jax.random.PRNGKeyArray]:
487
+ sample_rng, train_rng = jax.random.split(train_rng, 2)
488
+ loss, grad = grad_fn(state.params, batch, sample_rng)
489
+ return loss, grad, train_rng
490
+
491
+ loss, grad, new_train_rng = loss_and_grad(train_rng)
492
+ # self.log(grad) # NOTE uncomment to visualize gradient
493
+ grad = jax.lax.pmean(grad, axis_name = 'batch')
494
+ new_state = state.apply_gradients(grads = grad)
495
+ metrics: Dict[str, Any] = { 'loss': loss }
496
+ metrics = jax.lax.pmean(metrics, axis_name = 'batch')
497
+ def l2(xs) -> jax.Array:
498
+ return jnp.sqrt(sum([jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(xs)]))
499
+ metrics['l2_grads'] = l2(jax.tree_util.tree_leaves(grad))
500
+
501
+ return new_state, metrics, new_train_rng
502
+
503
+ p_train_step = jax.pmap(fun = train_step, axis_name = 'batch', donate_argnums = (0, ))
504
+ state = jax_utils.replicate(state)
505
+
506
+ train_metrics = []
507
+ train_metric = None
508
+
509
+ global_step: int = 0
510
+
511
+ if jax.process_index() == 0:
512
+ self._init_tracker_meta()
513
+ hyper_params = {
514
+ 'lr': lr,
515
+ 'lr_warmup': warmup,
516
+ 'lr_decay': decay,
517
+ 'weight_decay': weight_decay,
518
+ 'total_steps': total_steps,
519
+ 'batch_size': dataloader.batch_size // self.num_devices,
520
+ 'num_frames': num_frames,
521
+ 'sample_size': self.sample_size,
522
+ 'num_devices': self.num_devices,
523
+ 'seed': self.seed,
524
+ 'use_memory_efficient_attention': self.model.use_memory_efficient_attention,
525
+ 'only_temporal': self.only_temporal,
526
+ 'dtype': self.dtype_str,
527
+ 'param_dtype': self.param_dtype,
528
+ 'pretrained_model': self.pretrained_model,
529
+ 'model_config': self.model.config
530
+ }
531
+ if self._use_wandb:
532
+ self.log('Setting up wandb')
533
+ self._setup_wandb(hyper_params)
534
+ self.log(hyper_params)
535
+ output_path = os.path.join(output_dir, str(global_step), 'unet')
536
+ self.log(f'saving checkpoint to {output_path}')
537
+ self.model.save_pretrained(
538
+ save_directory = output_path,
539
+ params = self.get_params_from_state(state),#jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)),
540
+ is_main_process = True
541
+ )
542
+
543
+ pbar_epoch = tqdm(
544
+ total = epochs,
545
+ desc = 'Epochs',
546
+ smoothing = 1,
547
+ position = 0,
548
+ dynamic_ncols = True,
549
+ leave = True,
550
+ disable = jax.process_index() > 0
551
+ )
552
+ steps_per_epoch = len(dataloader) # TODO dataloader
553
+ for epoch in range(epochs):
554
+ pbar_steps = tqdm(
555
+ total = steps_per_epoch,
556
+ desc = 'Steps',
557
+ position = 1,
558
+ smoothing = 0.1,
559
+ dynamic_ncols = True,
560
+ leave = True,
561
+ disable = jax.process_index() > 0
562
+ )
563
+ for batch in dataloader:
564
+ # keep input + gt as float32, results in fp32 loss and grad
565
+ # otherwise uncomment the following to cast to the model dtype
566
+ # batch = { k: (v.astype(self.dtype) if v.dtype == np.float32 else v) for k,v in batch.items() }
567
+ batch = shard(batch)
568
+ state, train_metric, train_rngs = p_train_step(
569
+ state, batch, train_rngs
570
+ )
571
+ train_metrics.append(train_metric)
572
+ if global_step % log_every_step == 0 and jax.process_index() == 0:
573
+ train_metrics = jax_utils.unreplicate(train_metrics)
574
+ train_metrics = jax.tree_util.tree_map(lambda *m: jnp.array(m).mean(), *train_metrics)
575
+ if global_step == 0:
576
+ self.log(f'grad dtype: {train_metrics["l2_grads"].dtype}')
577
+ self.log(f'loss dtype: {train_metrics["loss"].dtype}')
578
+ train_metrics_dict = { k: v.item() for k, v in train_metrics.items() }
579
+ train_metrics_dict['lr'] = lr_schedule(global_step).item()
580
+ self.log_metrics(train_metrics_dict, step = global_step, epoch = epoch)
581
+ train_metrics = []
582
+ pbar_steps.update(1)
583
+ global_step += 1
584
+ if epoch % save_every_epoch == 0 and jax.process_index() == 0:
585
+ output_path = os.path.join(output_dir, str(global_step), 'unet')
586
+ self.log(f'saving checkpoint to {output_path}')
587
+ self.model.save_pretrained(
588
+ save_directory = output_path,
589
+ params = self.get_params_from_state(state),#jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params)),
590
+ is_main_process = True
591
+ )
592
+ self.log(f'checkpoint saved ')
593
+ if epoch % sample_every_epoch == 0 and jax.process_index() == 0:
594
+ images = self.sample(
595
+ params = state.params,
596
+ replicate_params = False,
597
+ prompt = 'dancing person',
598
+ image_path = 'testimage.png',
599
+ num_frames = num_frames,
600
+ steps = 50,
601
+ cfg = 9.0,
602
+ unload_after_usage = False
603
+ )
604
+ os.makedirs(os.path.join('image_output', str(epoch)), exist_ok = True)
605
+ for i, im in enumerate(images):
606
+ im.save(os.path.join('image_output', str(epoch), str(i).zfill(5) + '.png'), optimize = True)
607
+ pbar_epoch.update(1)
608
+
makeavid_sd/makeavid_sd/flax_impl/flax_unet_pseudo3d_blocks.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Tuple
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import flax.linen as nn
7
+
8
+ from .flax_attention_pseudo3d import TransformerPseudo3DModel
9
+ from .flax_resnet_pseudo3d import ResnetBlockPseudo3D, DownsamplePseudo3D, UpsamplePseudo3D
10
+
11
+
12
+ class UNetMidBlockPseudo3DCrossAttn(nn.Module):
13
+ in_channels: int
14
+ num_layers: int = 1
15
+ attn_num_head_channels: int = 1
16
+ use_memory_efficient_attention: bool = False
17
+ dtype: jnp.dtype = jnp.float32
18
+
19
+ def setup(self) -> None:
20
+ resnets = [
21
+ ResnetBlockPseudo3D(
22
+ in_channels = self.in_channels,
23
+ out_channels = self.in_channels,
24
+ dtype = self.dtype
25
+ )
26
+ ]
27
+ attentions = []
28
+ for _ in range(self.num_layers):
29
+ attn_block = TransformerPseudo3DModel(
30
+ in_channels = self.in_channels,
31
+ num_attention_heads = self.attn_num_head_channels,
32
+ attention_head_dim = self.in_channels // self.attn_num_head_channels,
33
+ num_layers = 1,
34
+ use_memory_efficient_attention = self.use_memory_efficient_attention,
35
+ dtype = self.dtype
36
+ )
37
+ attentions.append(attn_block)
38
+ res_block = ResnetBlockPseudo3D(
39
+ in_channels = self.in_channels,
40
+ out_channels = self.in_channels,
41
+ dtype = self.dtype
42
+ )
43
+ resnets.append(res_block)
44
+ self.attentions = attentions
45
+ self.resnets = resnets
46
+
47
+ def __call__(self,
48
+ hidden_states: jax.Array,
49
+ temb: jax.Array,
50
+ encoder_hidden_states = jax.Array
51
+ ) -> jax.Array:
52
+ hidden_states = self.resnets[0](hidden_states, temb)
53
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
54
+ hidden_states = attn(hidden_states, encoder_hidden_states)
55
+ hidden_states = resnet(hidden_states, temb)
56
+ return hidden_states
57
+
58
+
59
+ class CrossAttnDownBlockPseudo3D(nn.Module):
60
+ in_channels: int
61
+ out_channels: int
62
+ num_layers: int = 1
63
+ attn_num_head_channels: int = 1
64
+ add_downsample: bool = True
65
+ use_memory_efficient_attention: bool = False
66
+ dtype: jnp.dtype = jnp.float32
67
+
68
+ def setup(self) -> None:
69
+ attentions = []
70
+ resnets = []
71
+ for i in range(self.num_layers):
72
+ in_channels = self.in_channels if i == 0 else self.out_channels
73
+ res_block = ResnetBlockPseudo3D(
74
+ in_channels = in_channels,
75
+ out_channels = self.out_channels,
76
+ dtype = self.dtype
77
+ )
78
+ resnets.append(res_block)
79
+ attn_block = TransformerPseudo3DModel(
80
+ in_channels = self.out_channels,
81
+ num_attention_heads = self.attn_num_head_channels,
82
+ attention_head_dim = self.out_channels // self.attn_num_head_channels,
83
+ num_layers = 1,
84
+ use_memory_efficient_attention = self.use_memory_efficient_attention,
85
+ dtype = self.dtype
86
+ )
87
+ attentions.append(attn_block)
88
+ self.resnets = resnets
89
+ self.attentions = attentions
90
+
91
+ if self.add_downsample:
92
+ self.downsamplers_0 = DownsamplePseudo3D(
93
+ out_channels = self.out_channels,
94
+ dtype = self.dtype
95
+ )
96
+ else:
97
+ self.downsamplers_0 = None
98
+
99
+ def __call__(self,
100
+ hidden_states: jax.Array,
101
+ temb: jax.Array,
102
+ encoder_hidden_states: jax.Array
103
+ ) -> Tuple[jax.Array, jax.Array]:
104
+ output_states = ()
105
+ for resnet, attn in zip(self.resnets, self.attentions):
106
+ hidden_states = resnet(hidden_states, temb)
107
+ hidden_states = attn(hidden_states, encoder_hidden_states)
108
+ output_states += (hidden_states, )
109
+ if self.add_downsample:
110
+ hidden_states = self.downsamplers_0(hidden_states)
111
+ output_states += (hidden_states, )
112
+ return hidden_states, output_states
113
+
114
+
115
+ class DownBlockPseudo3D(nn.Module):
116
+ in_channels: int
117
+ out_channels: int
118
+ num_layers: int = 1
119
+ add_downsample: bool = True
120
+ dtype: jnp.dtype = jnp.float32
121
+
122
+ def setup(self) -> None:
123
+ resnets = []
124
+ for i in range(self.num_layers):
125
+ in_channels = self.in_channels if i == 0 else self.out_channels
126
+ res_block = ResnetBlockPseudo3D(
127
+ in_channels = in_channels,
128
+ out_channels = self.out_channels,
129
+ dtype = self.dtype
130
+ )
131
+ resnets.append(res_block)
132
+ self.resnets = resnets
133
+ if self.add_downsample:
134
+ self.downsamplers_0 = DownsamplePseudo3D(
135
+ out_channels = self.out_channels,
136
+ dtype = self.dtype
137
+ )
138
+ else:
139
+ self.downsamplers_0 = None
140
+
141
+ def __call__(self,
142
+ hidden_states: jax.Array,
143
+ temb: jax.Array
144
+ ) -> Tuple[jax.Array, jax.Array]:
145
+ output_states = ()
146
+ for resnet in self.resnets:
147
+ hidden_states = resnet(hidden_states, temb)
148
+ output_states += (hidden_states, )
149
+ if self.add_downsample:
150
+ hidden_states = self.downsamplers_0(hidden_states)
151
+ output_states += (hidden_states, )
152
+ return hidden_states, output_states
153
+
154
+
155
+ class CrossAttnUpBlockPseudo3D(nn.Module):
156
+ in_channels: int
157
+ out_channels: int
158
+ prev_output_channels: int
159
+ num_layers: int = 1
160
+ attn_num_head_channels: int = 1
161
+ add_upsample: bool = True
162
+ use_memory_efficient_attention: bool = False
163
+ dtype: jnp.dtype = jnp.float32
164
+
165
+ def setup(self) -> None:
166
+ resnets = []
167
+ attentions = []
168
+ for i in range(self.num_layers):
169
+ res_skip_channels = self.in_channels if (i == self.num_layers -1) else self.out_channels
170
+ resnet_in_channels = self.prev_output_channels if i == 0 else self.out_channels
171
+ res_block = ResnetBlockPseudo3D(
172
+ in_channels = resnet_in_channels + res_skip_channels,
173
+ out_channels = self.out_channels,
174
+ dtype = self.dtype
175
+ )
176
+ resnets.append(res_block)
177
+ attn_block = TransformerPseudo3DModel(
178
+ in_channels = self.out_channels,
179
+ num_attention_heads = self.attn_num_head_channels,
180
+ attention_head_dim = self.out_channels // self.attn_num_head_channels,
181
+ num_layers = 1,
182
+ use_memory_efficient_attention = self.use_memory_efficient_attention,
183
+ dtype = self.dtype
184
+ )
185
+ attentions.append(attn_block)
186
+ self.resnets = resnets
187
+ self.attentions = attentions
188
+ if self.add_upsample:
189
+ self.upsamplers_0 = UpsamplePseudo3D(
190
+ out_channels = self.out_channels,
191
+ dtype = self.dtype
192
+ )
193
+ else:
194
+ self.upsamplers_0 = None
195
+
196
+ def __call__(self,
197
+ hidden_states: jax.Array,
198
+ res_hidden_states_tuple: Tuple[jax.Array, ...],
199
+ temb: jax.Array,
200
+ encoder_hidden_states: jax.Array
201
+ ) -> jax.Array:
202
+ for resnet, attn in zip(self.resnets, self.attentions):
203
+ res_hidden_states = res_hidden_states_tuple[-1]
204
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
205
+ hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis = -1)
206
+ hidden_states = resnet(hidden_states, temb)
207
+ hidden_states = attn(hidden_states, encoder_hidden_states)
208
+ if self.add_upsample:
209
+ hidden_states = self.upsamplers_0(hidden_states)
210
+ return hidden_states
211
+
212
+
213
+ class UpBlockPseudo3D(nn.Module):
214
+ in_channels: int
215
+ out_channels: int
216
+ prev_output_channels: int
217
+ num_layers: int = 1
218
+ add_upsample: bool = True
219
+ dtype: jnp.dtype = jnp.float32
220
+
221
+ def setup(self) -> None:
222
+ resnets = []
223
+ for i in range(self.num_layers):
224
+ res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
225
+ resnet_in_channels = self.prev_output_channels if i == 0 else self.out_channels
226
+ res_block = ResnetBlockPseudo3D(
227
+ in_channels = resnet_in_channels + res_skip_channels,
228
+ out_channels = self.out_channels,
229
+ dtype = self.dtype
230
+ )
231
+ resnets.append(res_block)
232
+ self.resnets = resnets
233
+ if self.add_upsample:
234
+ self.upsamplers_0 = UpsamplePseudo3D(
235
+ out_channels = self.out_channels,
236
+ dtype = self.dtype
237
+ )
238
+ else:
239
+ self.upsamplers_0 = None
240
+
241
+ def __call__(self,
242
+ hidden_states: jax.Array,
243
+ res_hidden_states_tuple: Tuple[jax.Array, ...],
244
+ temb: jax.Array
245
+ ) -> jax.Array:
246
+ for resnet in self.resnets:
247
+ res_hidden_states = res_hidden_states_tuple[-1]
248
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
249
+ hidden_states = jnp.concatenate([hidden_states, res_hidden_states], axis = -1)
250
+ hidden_states = resnet(hidden_states, temb)
251
+ if self.add_upsample:
252
+ hidden_states = self.upsamplers_0(hidden_states)
253
+ return hidden_states
254
+
makeavid_sd/makeavid_sd/flax_impl/flax_unet_pseudo3d_condition.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Tuple, Union
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import flax.linen as nn
7
+ from flax.core.frozen_dict import FrozenDict
8
+
9
+ from diffusers.configuration_utils import ConfigMixin, flax_register_to_config
10
+ from diffusers.models.modeling_flax_utils import FlaxModelMixin
11
+ from diffusers.utils import BaseOutput
12
+
13
+ from .flax_unet_pseudo3d_blocks import (
14
+ CrossAttnDownBlockPseudo3D,
15
+ CrossAttnUpBlockPseudo3D,
16
+ DownBlockPseudo3D,
17
+ UpBlockPseudo3D,
18
+ UNetMidBlockPseudo3DCrossAttn
19
+ )
20
+ #from flax_embeddings import (
21
+ # TimestepEmbedding,
22
+ # Timesteps
23
+ #)
24
+ from diffusers.models.embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
25
+ from .flax_resnet_pseudo3d import ConvPseudo3D
26
+
27
+
28
+ class UNetPseudo3DConditionOutput(BaseOutput):
29
+ sample: jax.Array
30
+
31
+
32
+ @flax_register_to_config
33
+ class UNetPseudo3DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
34
+ sample_size: Union[int, Tuple[int, int]] = (64, 64)
35
+ in_channels: int = 4
36
+ out_channels: int = 4
37
+ down_block_types: Tuple[str] = (
38
+ "CrossAttnDownBlockPseudo3D",
39
+ "CrossAttnDownBlockPseudo3D",
40
+ "CrossAttnDownBlockPseudo3D",
41
+ "DownBlockPseudo3D"
42
+ )
43
+ up_block_types: Tuple[str] = (
44
+ "UpBlockPseudo3D",
45
+ "CrossAttnUpBlockPseudo3D",
46
+ "CrossAttnUpBlockPseudo3D",
47
+ "CrossAttnUpBlockPseudo3D"
48
+ )
49
+ block_out_channels: Tuple[int] = (
50
+ 320,
51
+ 640,
52
+ 1280,
53
+ 1280
54
+ )
55
+ layers_per_block: int = 2
56
+ attention_head_dim: Union[int, Tuple[int]] = 8
57
+ cross_attention_dim: int = 768
58
+ flip_sin_to_cos: bool = True
59
+ freq_shift: int = 0
60
+ use_memory_efficient_attention: bool = False
61
+ dtype: jnp.dtype = jnp.float32
62
+ param_dtype: str = 'float32'
63
+
64
+ def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
65
+ if self.param_dtype == 'bfloat16':
66
+ param_dtype = jnp.bfloat16
67
+ elif self.param_dtype == 'float16':
68
+ param_dtype = jnp.float16
69
+ elif self.param_dtype == 'float32':
70
+ param_dtype = jnp.float32
71
+ else:
72
+ raise ValueError(f'unknown parameter type: {self.param_dtype}')
73
+ sample_size = self.sample_size
74
+ if isinstance(sample_size, int):
75
+ sample_size = (sample_size, sample_size)
76
+ sample_shape = (1, self.in_channels, 1, *sample_size)
77
+ sample = jnp.zeros(sample_shape, dtype = param_dtype)
78
+ timesteps = jnp.ones((1, ), dtype = jnp.int32)
79
+ encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype = param_dtype)
80
+ params_rng, dropout_rng = jax.random.split(rng)
81
+ rngs = { "params": params_rng, "dropout": dropout_rng }
82
+ return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
83
+
84
+ def setup(self) -> None:
85
+ if isinstance(self.attention_head_dim, int):
86
+ attention_head_dim = (self.attention_head_dim, ) * len(self.down_block_types)
87
+ else:
88
+ attention_head_dim = self.attention_head_dim
89
+ time_embed_dim = self.block_out_channels[0] * 4
90
+ self.conv_in = ConvPseudo3D(
91
+ features = self.block_out_channels[0],
92
+ kernel_size = (3, 3),
93
+ strides = (1, 1),
94
+ padding = ((1, 1), (1, 1)),
95
+ dtype = self.dtype
96
+ )
97
+ self.time_proj = FlaxTimesteps(
98
+ dim = self.block_out_channels[0],
99
+ flip_sin_to_cos = self.flip_sin_to_cos,
100
+ freq_shift = self.freq_shift
101
+ )
102
+ self.time_embedding = FlaxTimestepEmbedding(
103
+ time_embed_dim = time_embed_dim,
104
+ dtype = self.dtype
105
+ )
106
+ down_blocks = []
107
+ output_channels = self.block_out_channels[0]
108
+ for i, down_block_type in enumerate(self.down_block_types):
109
+ input_channels = output_channels
110
+ output_channels = self.block_out_channels[i]
111
+ is_final_block = i == len(self.block_out_channels) - 1
112
+ # allows loading 3d models with old layer type names in their configs
113
+ # eg. 2D instead of Pseudo3D, like lxj's timelapse model
114
+ if down_block_type in ['CrossAttnDownBlockPseudo3D', 'CrossAttnDownBlock2D']:
115
+ down_block = CrossAttnDownBlockPseudo3D(
116
+ in_channels = input_channels,
117
+ out_channels = output_channels,
118
+ num_layers = self.layers_per_block,
119
+ attn_num_head_channels = attention_head_dim[i],
120
+ add_downsample = not is_final_block,
121
+ use_memory_efficient_attention = self.use_memory_efficient_attention,
122
+ dtype = self.dtype
123
+ )
124
+ elif down_block_type in ['DownBlockPseudo3D', 'DownBlock2D']:
125
+ down_block = DownBlockPseudo3D(
126
+ in_channels = input_channels,
127
+ out_channels = output_channels,
128
+ num_layers = self.layers_per_block,
129
+ add_downsample = not is_final_block,
130
+ dtype = self.dtype
131
+ )
132
+ else:
133
+ raise NotImplementedError(f'Unimplemented down block type: {down_block_type}')
134
+ down_blocks.append(down_block)
135
+ self.down_blocks = down_blocks
136
+ self.mid_block = UNetMidBlockPseudo3DCrossAttn(
137
+ in_channels = self.block_out_channels[-1],
138
+ attn_num_head_channels = attention_head_dim[-1],
139
+ use_memory_efficient_attention = self.use_memory_efficient_attention,
140
+ dtype = self.dtype
141
+ )
142
+ up_blocks = []
143
+ reversed_block_out_channels = list(reversed(self.block_out_channels))
144
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
145
+ output_channels = reversed_block_out_channels[0]
146
+ for i, up_block_type in enumerate(self.up_block_types):
147
+ prev_output_channels = output_channels
148
+ output_channels = reversed_block_out_channels[i]
149
+ input_channels = reversed_block_out_channels[min(i + 1, len(self.block_out_channels) - 1)]
150
+ is_final_block = i == len(self.block_out_channels) - 1
151
+ if up_block_type in ['CrossAttnUpBlockPseudo3D', 'CrossAttnUpBlock2D']:
152
+ up_block = CrossAttnUpBlockPseudo3D(
153
+ in_channels = input_channels,
154
+ out_channels = output_channels,
155
+ prev_output_channels = prev_output_channels,
156
+ num_layers = self.layers_per_block + 1,
157
+ attn_num_head_channels = reversed_attention_head_dim[i],
158
+ add_upsample = not is_final_block,
159
+ use_memory_efficient_attention = self.use_memory_efficient_attention,
160
+ dtype = self.dtype
161
+ )
162
+ elif up_block_type in ['UpBlockPseudo3D', 'UpBlock2D']:
163
+ up_block = UpBlockPseudo3D(
164
+ in_channels = input_channels,
165
+ out_channels = output_channels,
166
+ prev_output_channels = prev_output_channels,
167
+ num_layers = self.layers_per_block + 1,
168
+ add_upsample = not is_final_block,
169
+ dtype = self.dtype
170
+ )
171
+ else:
172
+ raise NotImplementedError(f'Unimplemented up block type: {up_block_type}')
173
+ up_blocks.append(up_block)
174
+ self.up_blocks = up_blocks
175
+ self.conv_norm_out = nn.GroupNorm(
176
+ num_groups = 32,
177
+ epsilon = 1e-5
178
+ )
179
+ self.conv_out = ConvPseudo3D(
180
+ features = self.out_channels,
181
+ kernel_size = (3, 3),
182
+ strides = (1, 1),
183
+ padding = ((1, 1), (1, 1)),
184
+ dtype = self.dtype
185
+ )
186
+
187
+ def __call__(self,
188
+ sample: jax.Array,
189
+ timesteps: jax.Array,
190
+ encoder_hidden_states: jax.Array,
191
+ return_dict: bool = True
192
+ ) -> Union[UNetPseudo3DConditionOutput, Tuple[jax.Array]]:
193
+ if timesteps.dtype != jnp.float32:
194
+ timesteps = timesteps.astype(dtype = jnp.float32)
195
+ if len(timesteps.shape) == 0:
196
+ timesteps = jnp.expand_dims(timesteps, 0)
197
+ # b,c,f,h,w -> b,f,h,w,c
198
+ sample = sample.transpose((0, 2, 3, 4, 1))
199
+
200
+ t_emb = self.time_proj(timesteps)
201
+ t_emb = self.time_embedding(t_emb)
202
+ sample = self.conv_in(sample)
203
+ down_block_res_samples = (sample, )
204
+ for down_block in self.down_blocks:
205
+ if isinstance(down_block, CrossAttnDownBlockPseudo3D):
206
+ sample, res_samples = down_block(
207
+ hidden_states = sample,
208
+ temb = t_emb,
209
+ encoder_hidden_states = encoder_hidden_states
210
+ )
211
+ elif isinstance(down_block, DownBlockPseudo3D):
212
+ sample, res_samples = down_block(
213
+ hidden_states = sample,
214
+ temb = t_emb
215
+ )
216
+ else:
217
+ raise NotImplementedError(f'Unimplemented down block type: {down_block.__class__.__name__}')
218
+ down_block_res_samples += res_samples
219
+ sample = self.mid_block(
220
+ hidden_states = sample,
221
+ temb = t_emb,
222
+ encoder_hidden_states = encoder_hidden_states
223
+ )
224
+ for up_block in self.up_blocks:
225
+ res_samples = down_block_res_samples[-(self.layers_per_block + 1):]
226
+ down_block_res_samples = down_block_res_samples[:-(self.layers_per_block + 1)]
227
+ if isinstance(up_block, CrossAttnUpBlockPseudo3D):
228
+ sample = up_block(
229
+ hidden_states = sample,
230
+ temb = t_emb,
231
+ encoder_hidden_states = encoder_hidden_states,
232
+ res_hidden_states_tuple = res_samples
233
+ )
234
+ elif isinstance(up_block, UpBlockPseudo3D):
235
+ sample = up_block(
236
+ hidden_states = sample,
237
+ temb = t_emb,
238
+ res_hidden_states_tuple = res_samples
239
+ )
240
+ else:
241
+ raise NotImplementedError(f'Unimplemented up block type: {up_block.__class__.__name__}')
242
+ sample = self.conv_norm_out(sample)
243
+ sample = nn.silu(sample)
244
+ sample = self.conv_out(sample)
245
+
246
+ # b,f,h,w,c -> b,c,f,h,w
247
+ sample = sample.transpose((0, 4, 1, 2, 3))
248
+ if not return_dict:
249
+ return (sample, )
250
+ return UNetPseudo3DConditionOutput(sample = sample)
251
+
makeavid_sd/makeavid_sd/flax_impl/train.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import jax
3
+ _ = jax.device_count() # ugly hack to prevent tpu comms to lock/race or smth smh
4
+
5
+ from typing import Tuple, Optional
6
+ import os
7
+ from argparse import ArgumentParser
8
+
9
+ from flax_trainer import FlaxTrainerUNetPseudo3D
10
+ from dataset import load_dataset
11
+
12
+ def train(
13
+ dataset_path: str,
14
+ model_path: str,
15
+ output_dir: str,
16
+ dataset_cache_dir: Optional[str] = None,
17
+ from_pt: bool = True,
18
+ convert2d: bool = False,
19
+ only_temporal: bool = True,
20
+ sample_size: Tuple[int, int] = (64, 64),
21
+ lr: float = 5e-5,
22
+ batch_size: int = 1,
23
+ num_frames: int = 24,
24
+ epochs: int = 10,
25
+ warmup: float = 0.1,
26
+ decay: float = 0.0,
27
+ weight_decay: float = 1e-2,
28
+ log_every_step: int = 50,
29
+ save_every_epoch: int = 1,
30
+ sample_every_epoch: int = 1,
31
+ seed: int = 0,
32
+ dtype: str = 'bfloat16',
33
+ param_dtype: str = 'float32',
34
+ use_memory_efficient_attention: bool = True,
35
+ verbose: bool = True,
36
+ use_wandb: bool = False
37
+ ) -> None:
38
+ log = lambda x: print(x) if verbose else None
39
+ log('\n----------------')
40
+ log('Init trainer')
41
+ trainer = FlaxTrainerUNetPseudo3D(
42
+ model_path = model_path,
43
+ from_pt = from_pt,
44
+ convert2d = convert2d,
45
+ sample_size = sample_size,
46
+ seed = seed,
47
+ dtype = dtype,
48
+ param_dtype = param_dtype,
49
+ use_memory_efficient_attention = use_memory_efficient_attention,
50
+ verbose = verbose,
51
+ only_temporal = only_temporal
52
+ )
53
+ log('\n----------------')
54
+ log('Init dataset')
55
+ dataloader = load_dataset(
56
+ dataset_path = dataset_path,
57
+ model_path = model_path,
58
+ cache_dir = dataset_cache_dir,
59
+ batch_size = batch_size * trainer.num_devices,
60
+ num_frames = num_frames,
61
+ num_workers = min(trainer.num_devices * 2, os.cpu_count() - 1),
62
+ as_numpy = True,
63
+ shuffle = True
64
+ )
65
+ log('\n----------------')
66
+ log('Train')
67
+ if use_wandb:
68
+ trainer.enable_wandb()
69
+ trainer.train(
70
+ dataloader = dataloader,
71
+ epochs = epochs,
72
+ num_frames = num_frames,
73
+ log_every_step = log_every_step,
74
+ save_every_epoch = save_every_epoch,
75
+ sample_every_epoch = sample_every_epoch,
76
+ lr = lr,
77
+ warmup = warmup,
78
+ decay = decay,
79
+ weight_decay = weight_decay,
80
+ output_dir = output_dir
81
+ )
82
+ log('\n----------------')
83
+ log('Done')
84
+
85
+
86
+ if __name__ == '__main__':
87
+ parser = ArgumentParser()
88
+ bool_type = lambda x: x.lower() in ['true', '1', 'yes']
89
+ parser.add_argument('-v', '--verbose', type = bool_type, default = True)
90
+ parser.add_argument('-d', '--dataset_path', required = True)
91
+ parser.add_argument('-m', '--model_path', required = True)
92
+ parser.add_argument('-o', '--output_dir', required = True)
93
+ parser.add_argument('-b', '--batch_size', type = int, default = 1)
94
+ parser.add_argument('-f', '--num_frames', type = int, default = 24)
95
+ parser.add_argument('-e', '--epochs', type = int, default = 2)
96
+ parser.add_argument('--only_temporal', type = bool_type, default = True)
97
+ parser.add_argument('--dataset_cache_dir', type = str, default = None)
98
+ parser.add_argument('--from_pt', type = bool_type, default = True)
99
+ parser.add_argument('--convert2d', type = bool_type, default = False)
100
+ parser.add_argument('--lr', type = float, default = 1e-4)
101
+ parser.add_argument('--warmup', type = float, default = 0.1)
102
+ parser.add_argument('--decay', type = float, default = 0.0)
103
+ parser.add_argument('--weight_decay', type = float, default = 1e-2)
104
+ parser.add_argument('--sample_size', type = int, nargs = 2, default = [64, 64])
105
+ parser.add_argument('--log_every_step', type = int, default = 250)
106
+ parser.add_argument('--save_every_epoch', type = int, default = 1)
107
+ parser.add_argument('--sample_every_epoch', type = int, default = 1)
108
+ parser.add_argument('--seed', type = int, default = 0)
109
+ parser.add_argument('--use_memory_efficient_attention', type = bool_type, default = True)
110
+ parser.add_argument('--dtype', choices = ['float32', 'bfloat16', 'float16'], default = 'bfloat16')
111
+ parser.add_argument('--param_dtype', choices = ['float32', 'bfloat16', 'float16'], default = 'float32')
112
+ parser.add_argument('--wandb', type = bool_type, default = False)
113
+ args = parser.parse_args()
114
+ args.sample_size = tuple(args.sample_size)
115
+ if args.verbose:
116
+ print(args)
117
+ train(
118
+ dataset_path = args.dataset_path,
119
+ model_path = args.model_path,
120
+ from_pt = args.from_pt,
121
+ convert2d = args.convert2d,
122
+ only_temporal = args.only_temporal,
123
+ output_dir = args.output_dir,
124
+ dataset_cache_dir = args.dataset_cache_dir,
125
+ batch_size = args.batch_size,
126
+ num_frames = args.num_frames,
127
+ epochs = args.epochs,
128
+ lr = args.lr,
129
+ warmup = args.warmup,
130
+ decay = args.decay,
131
+ weight_decay = args.weight_decay,
132
+ sample_size = args.sample_size,
133
+ seed = args.seed,
134
+ dtype = args.dtype,
135
+ param_dtype = args.param_dtype,
136
+ use_memory_efficient_attention = args.use_memory_efficient_attention,
137
+ log_every_step = args.log_every_step,
138
+ save_every_epoch = args.save_every_epoch,
139
+ sample_every_epoch = args.sample_every_epoch,
140
+ verbose = args.verbose,
141
+ use_wandb = args.wandb
142
+ )
143
+
makeavid_sd/makeavid_sd/flax_impl/train.sh ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/sh
2
+
3
+ #export WANDB_API_KEY="your_api_key"
4
+ export WANDB_ENTITY="tempofunk"
5
+ export WANDB_JOB_TYPE="train"
6
+ export WANDB_PROJECT="makeavid-sd-tpu"
7
+
8
+ python train.py \
9
+ --dataset_path ../storage/dataset/tempofunk-s \
10
+ --model_path ../storage/trained_models/ep20 \
11
+ --from_pt False \
12
+ --convert2d False \
13
+ --only_temporal True \
14
+ --output_dir ../storage/output \
15
+ --batch_size 1 \
16
+ --num_frames 24 \
17
+ --epochs 20 \
18
+ --lr 0.00005 \
19
+ --warmup 0.1 \
20
+ --decay 0.0 \
21
+ --sample_size 64 64 \
22
+ --log_every_step 50 \
23
+ --save_every_epoch 1 \
24
+ --sample_every_epoch 1 \
25
+ --seed 2 \
26
+ --use_memory_efficient_attention True \
27
+ --dtype bfloat16 \
28
+ --param_dtype float32 \
29
+ --verbose True \
30
+ --dataset_cache_dir ../storage/cache/hf/datasets \
31
+ --wandb True
32
+
33
+ # sudo rm /tmp/libtpu_lockfile
34
+
makeavid_sd/makeavid_sd/inference.py ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Any, Union, Tuple, List, Dict
3
+ import os
4
+ import gc
5
+ from functools import partial
6
+
7
+ import jax
8
+ import jax.numpy as jnp
9
+ import numpy as np
10
+
11
+ from flax.core.frozen_dict import FrozenDict
12
+ from flax import jax_utils
13
+ from flax.training.common_utils import shard
14
+ from PIL import Image
15
+ import einops
16
+
17
+ from diffusers import FlaxAutoencoderKL, FlaxUNet2DConditionModel
18
+ from diffusers import (
19
+ FlaxDDIMScheduler,
20
+ FlaxDDPMScheduler,
21
+ FlaxPNDMScheduler,
22
+ FlaxLMSDiscreteScheduler,
23
+ FlaxDPMSolverMultistepScheduler,
24
+ FlaxKarrasVeScheduler,
25
+ FlaxScoreSdeVeScheduler
26
+ )
27
+
28
+ from transformers import FlaxCLIPTextModel, CLIPTokenizer
29
+
30
+ from .flax_impl.flax_unet_pseudo3d_condition import UNetPseudo3DConditionModel
31
+
32
+ SchedulerType = Union[
33
+ FlaxDDIMScheduler,
34
+ FlaxDDPMScheduler,
35
+ FlaxPNDMScheduler,
36
+ FlaxLMSDiscreteScheduler,
37
+ FlaxDPMSolverMultistepScheduler,
38
+ FlaxKarrasVeScheduler,
39
+ FlaxScoreSdeVeScheduler
40
+ ]
41
+
42
+ def dtypestr(x: jnp.dtype):
43
+ if x == jnp.float32: return 'float32'
44
+ elif x == jnp.float16: return 'float16'
45
+ elif x == jnp.bfloat16: return 'bfloat16'
46
+ else: raise
47
+ def castto(dtype, m, x):
48
+ if dtype == jnp.float32: return m.to_fp32(x)
49
+ elif dtype == jnp.float16: return m.to_fp16(x)
50
+ elif dtype == jnp.bfloat16: return m.to_bf16(x)
51
+ else: raise
52
+
53
+ class InferenceUNetPseudo3D:
54
+ def __init__(self,
55
+ model_path: str,
56
+ scheduler_cls: SchedulerType = FlaxDDIMScheduler,
57
+ dtype: jnp.dtype = jnp.float16,
58
+ hf_auth_token: Union[str, None] = None
59
+ ) -> None:
60
+ self.dtype = dtype
61
+ self.model_path = model_path
62
+ self.hf_auth_token = hf_auth_token
63
+
64
+ self.params: Dict[str, FrozenDict[str, Any]] = {}
65
+ unet, unet_params = UNetPseudo3DConditionModel.from_pretrained(
66
+ self.model_path,
67
+ subfolder = 'unet',
68
+ from_pt = False,
69
+ sample_size = (64, 64),
70
+ dtype = self.dtype,
71
+ param_dtype = dtypestr(self.dtype),
72
+ use_memory_efficient_attention = True,
73
+ use_auth_token = self.hf_auth_token
74
+ )
75
+ self.unet: UNetPseudo3DConditionModel = unet
76
+ unet_params = castto(self.dtype, self.unet, unet_params)
77
+ self.params['unet'] = FrozenDict(unet_params)
78
+ del unet_params
79
+ vae, vae_params = FlaxAutoencoderKL.from_pretrained(
80
+ self.model_path,
81
+ subfolder = 'vae',
82
+ from_pt = True,
83
+ dtype = self.dtype,
84
+ use_auth_token = self.hf_auth_token
85
+ )
86
+ self.vae: FlaxAutoencoderKL = vae
87
+ vae_params = castto(self.dtype, self.vae, vae_params)
88
+ self.params['vae'] = FrozenDict(vae_params)
89
+ del vae_params
90
+ text_encoder = FlaxCLIPTextModel.from_pretrained(
91
+ self.model_path,
92
+ subfolder = 'text_encoder',
93
+ from_pt = True,
94
+ dtype = self.dtype,
95
+ use_auth_token = self.hf_auth_token
96
+ )
97
+ text_encoder_params = text_encoder.params
98
+ del text_encoder._params
99
+ text_encoder_params = castto(self.dtype, text_encoder, text_encoder_params)
100
+ self.text_encoder: FlaxCLIPTextModel = text_encoder
101
+ self.params['text_encoder'] = FrozenDict(text_encoder_params)
102
+ del text_encoder_params
103
+ imunet, imunet_params = FlaxUNet2DConditionModel.from_pretrained(
104
+ 'runwayml/stable-diffusion-v1-5',
105
+ subfolder = 'unet',
106
+ from_pt = True,
107
+ dtype = self.dtype,
108
+ use_memory_efficient_attention = True,
109
+ use_auth_token = self.hf_auth_token
110
+ )
111
+ imunet_params = castto(self.dtype, imunet, imunet_params)
112
+ self.imunet: FlaxUNet2DConditionModel = imunet
113
+ self.params['imunet'] = FrozenDict(imunet_params)
114
+ del imunet_params
115
+ self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
116
+ self.model_path,
117
+ subfolder = 'tokenizer',
118
+ use_auth_token = self.hf_auth_token
119
+ )
120
+ scheduler, scheduler_state = scheduler_cls.from_pretrained(
121
+ self.model_path,
122
+ subfolder = 'scheduler',
123
+ dtype = jnp.float32,
124
+ use_auth_token = self.hf_api_key
125
+ )
126
+ self.scheduler: scheduler_cls = scheduler
127
+ self.params['scheduler'] = scheduler_state
128
+ self.vae_scale_factor: int = int(2 ** (len(self.vae.config.block_out_channels) - 1))
129
+ self.device_count = jax.device_count()
130
+ gc.collect()
131
+
132
+ def set_scheduler(self, scheduler_cls: SchedulerType) -> None:
133
+ scheduler, scheduler_state = scheduler_cls.from_pretrained(
134
+ self.model_path,
135
+ subfolder = 'scheduler',
136
+ dtype = jnp.float32,
137
+ use_auth_token = self.hf_api_key
138
+ )
139
+ self.scheduler: scheduler_cls = scheduler
140
+ self.params['scheduler'] = scheduler_state
141
+
142
+ def prepare_inputs(self,
143
+ prompt: List[str],
144
+ neg_prompt: List[str],
145
+ hint_image: List[Image.Image],
146
+ mask_image: List[Image.Image],
147
+ width: int,
148
+ height: int
149
+ ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: # prompt, neg_prompt, hint_image, mask_image
150
+ tokens = self.tokenizer(
151
+ prompt,
152
+ truncation = True,
153
+ return_overflowing_tokens = False,
154
+ max_length = 77, #self.text_encoder.config.max_length defaults to 20 if its not in the config smh
155
+ padding = 'max_length',
156
+ return_tensors = 'np'
157
+ ).input_ids
158
+ tokens = jnp.array(tokens, dtype = jnp.int32)
159
+ neg_tokens = self.tokenizer(
160
+ neg_prompt,
161
+ truncation = True,
162
+ return_overflowing_tokens = False,
163
+ max_length = 77,
164
+ padding = 'max_length',
165
+ return_tensors = 'np'
166
+ ).input_ids
167
+ neg_tokens = jnp.array(neg_tokens, dtype = jnp.int32)
168
+ for i,im in enumerate(hint_image):
169
+ if im.size != (width, height):
170
+ hint_image[i] = hint_image[i].resize((width, height), resample = Image.Resampling.LANCZOS)
171
+ for i,im in enumerate(mask_image):
172
+ if im.size != (width, height):
173
+ mask_image[i] = mask_image[i].resize((width, height), resample = Image.Resampling.LANCZOS)
174
+ # b,h,w,c | c == 3
175
+ hint = jnp.concatenate(
176
+ [ jnp.expand_dims(np.asarray(x.convert('RGB')), axis = 0) for x in hint_image ],
177
+ axis = 0
178
+ ).astype(jnp.float32)
179
+ # scale -1,1
180
+ hint = (hint / 255) * 2 - 1
181
+ # b,h,w,c | c == 1
182
+ mask = jnp.concatenate(
183
+ [ jnp.expand_dims(np.asarray(x.convert('L')), axis = (0, -1)) for x in mask_image ],
184
+ axis = 0
185
+ ).astype(jnp.float32)
186
+ # scale -1,1
187
+ mask = (mask / 255) * 2 - 1
188
+ # binarize mask
189
+ mask = mask.at[mask < 0.5].set(0)
190
+ mask = mask.at[mask >= 0.5].set(1)
191
+ # mask
192
+ hint = hint * (mask < 0.5)
193
+ # b,h,w,c -> b,c,h,w
194
+ hint = hint.transpose((0,3,1,2))
195
+ mask = mask.transpose((0,3,1,2))
196
+ return tokens, neg_tokens, hint, mask
197
+
198
+ def generate(self,
199
+ prompt: Union[str, List[str]],
200
+ inference_steps: int,
201
+ hint_image: Union[Image.Image, List[Image.Image], None] = None,
202
+ mask_image: Union[Image.Image, List[Image.Image], None] = None,
203
+ neg_prompt: Union[str, List[str]] = '',
204
+ cfg: float = 10.0,
205
+ num_frames: int = 24,
206
+ width: int = 512,
207
+ height: int = 512,
208
+ seed: int = 0
209
+ ) -> List[List[Image.Image]]:
210
+ assert inference_steps > 0, f'number of inference steps must be > 0 but is {inference_steps}'
211
+ assert num_frames > 0, f'number of frames must be > 0 but is {num_frames}'
212
+ assert width % 32 == 0, f'width must be divisible by 32 but is {width}'
213
+ assert height % 32 == 0, f'height must be divisible by 32 but is {height}'
214
+ if isinstance(prompt, str):
215
+ prompt = [ prompt ]
216
+ batch_size = len(prompt)
217
+ assert batch_size % self.device_count == 0, f'batch size must be multiple of {self.device_count}'
218
+ if hint_image is None:
219
+ hint_image = Image.new('RGB', (width, height), color = (0,0,0))
220
+ use_imagegen = True
221
+ else:
222
+ use_imagegen = False
223
+ if isinstance(hint_image, Image.Image):
224
+ hint_image = [ hint_image ] * batch_size
225
+ assert len(hint_image) == batch_size, f'number of hint images must be equal to batch size {batch_size} but is {len(hint_image)}'
226
+ if mask_image is None:
227
+ mask_image = Image.new('L', hint_image[0].size, color = 0)
228
+ if isinstance(mask_image, Image.Image):
229
+ mask_image = [ mask_image ] * batch_size
230
+ assert len(mask_image) == batch_size, f'number of mask images must be equal to batch size {batch_size} but is {len(mask_image)}'
231
+ if isinstance(neg_prompt, str):
232
+ neg_prompt = [ neg_prompt ] * batch_size
233
+ assert len(neg_prompt) == batch_size, f'number of negative prompts must be equal to batch size {batch_size} but is {len(neg_prompt)}'
234
+ tokens, neg_tokens, hint, mask = self.prepare_inputs(
235
+ prompt = prompt,
236
+ neg_prompt = neg_prompt,
237
+ hint_image = hint_image,
238
+ mask_image = mask_image,
239
+ width = width,
240
+ height = height
241
+ )
242
+ # NOTE splitting rngs is not deterministic,
243
+ # running on different device counts gives different seeds
244
+ #rng = jax.random.PRNGKey(seed)
245
+ #rngs = jax.random.split(rng, self.device_count)
246
+ # manually assign seeded RNGs to devices for reproducability
247
+ rngs = jnp.array([ jax.random.PRNGKey(seed + i) for i in range(self.device_count) ])
248
+ params = jax_utils.replicate(self.params)
249
+ tokens = shard(tokens)
250
+ neg_tokens = shard(neg_tokens)
251
+ hint = shard(hint)
252
+ mask = shard(mask)
253
+ images = _p_generate(self,
254
+ tokens,
255
+ neg_tokens,
256
+ hint,
257
+ mask,
258
+ inference_steps,
259
+ num_frames,
260
+ height,
261
+ width,
262
+ cfg,
263
+ rngs,
264
+ params,
265
+ use_imagegen
266
+ )
267
+ if images.ndim == 5:
268
+ images = einops.rearrange(images, 'd f c h w -> (d f) h w c')
269
+ else:
270
+ images = einops.rearrange(images, 'f c h w -> f h w c')
271
+ # to cpu
272
+ images = np.array(images)
273
+ images = [ Image.fromarray(x) for x in images ]
274
+ return images
275
+
276
+ def _generate(self,
277
+ tokens: jnp.ndarray,
278
+ neg_tokens: jnp.ndarray,
279
+ hint: jnp.ndarray,
280
+ mask: jnp.ndarray,
281
+ inference_steps: int,
282
+ num_frames,
283
+ height,
284
+ width,
285
+ cfg: float,
286
+ rng: jax.random.KeyArray,
287
+ params: Union[Dict[str, Any], FrozenDict[str, Any]],
288
+ use_imagegen: bool
289
+ ) -> List[Image.Image]:
290
+ batch_size = tokens.shape[0]
291
+ latent_h = height // self.vae_scale_factor
292
+ latent_w = width // self.vae_scale_factor
293
+ latent_shape = (
294
+ batch_size,
295
+ self.vae.config.latent_channels,
296
+ num_frames,
297
+ latent_h,
298
+ latent_w
299
+ )
300
+ encoded_prompt = self.text_encoder(tokens, params = params['text_encoder'])[0]
301
+ encoded_neg_prompt = self.text_encoder(neg_tokens, params = params['text_encoder'])[0]
302
+
303
+ if use_imagegen:
304
+ image_latent_shape = (batch_size, self.vae.config.latent_channels, latent_h, latent_w)
305
+ image_latents = jax.random.normal(
306
+ rng,
307
+ shape = image_latent_shape,
308
+ dtype = jnp.float32
309
+ ) * params['scheduler'].init_noise_sigma
310
+ image_scheduler_state = self.scheduler.set_timesteps(
311
+ params['scheduler'],
312
+ num_inference_steps = inference_steps,
313
+ shape = image_latents.shape
314
+ )
315
+ def image_sample_loop(step, args):
316
+ image_latents, image_scheduler_state = args
317
+ t = image_scheduler_state.timesteps[step]
318
+ tt = jnp.broadcast_to(t, image_latents.shape[0])
319
+ latents_input = self.scheduler.scale_model_input(image_scheduler_state, image_latents, t)
320
+ noise_pred = self.imunet.apply(
321
+ {'params': params['imunet']},
322
+ latents_input,
323
+ tt,
324
+ encoder_hidden_states = encoded_prompt
325
+ ).sample
326
+ noise_pred_uncond = self.imunet.apply(
327
+ {'params': params['imunet']},
328
+ latents_input,
329
+ tt,
330
+ encoder_hidden_states = encoded_neg_prompt
331
+ ).sample
332
+ noise_pred = noise_pred_uncond + cfg * (noise_pred - noise_pred_uncond)
333
+ image_latents, image_scheduler_state = self.scheduler.step(
334
+ image_scheduler_state,
335
+ noise_pred.astype(jnp.float32),
336
+ t,
337
+ image_latents
338
+ ).to_tuple()
339
+ return image_latents, image_scheduler_state
340
+ image_latents, _ = jax.lax.fori_loop(
341
+ 0, inference_steps,
342
+ image_sample_loop,
343
+ (image_latents, image_scheduler_state)
344
+ )
345
+ hint = image_latents
346
+ else:
347
+ hint = self.vae.apply(
348
+ {'params': params['vae']},
349
+ hint,
350
+ method = self.vae.encode
351
+ ).latent_dist.mean * self.vae.config.scaling_factor
352
+ # NOTE vae keeps channels last for encode, but rearranges to channels first for decode
353
+ # b0 h1 w2 c3 -> b0 c3 h1 w2
354
+ hint = hint.transpose((0, 3, 1, 2))
355
+
356
+ hint = jnp.expand_dims(hint, axis = 2).repeat(num_frames, axis = 2)
357
+ mask = jax.image.resize(mask, (*mask.shape[:-2], *hint.shape[-2:]), method = 'nearest')
358
+ mask = jnp.expand_dims(mask, axis = 2).repeat(num_frames, axis = 2)
359
+ # NOTE jax normal distribution is shit with float16 + bfloat16
360
+ # SEE https://github.com/google/jax/discussions/13798
361
+ # generate random at float32
362
+ latents = jax.random.normal(
363
+ rng,
364
+ shape = latent_shape,
365
+ dtype = jnp.float32
366
+ ) * params['scheduler'].init_noise_sigma
367
+ scheduler_state = self.scheduler.set_timesteps(
368
+ params['scheduler'],
369
+ num_inference_steps = inference_steps,
370
+ shape = latents.shape
371
+ )
372
+
373
+ def sample_loop(step, args):
374
+ latents, scheduler_state = args
375
+ t = scheduler_state.timesteps[step]#jnp.array(scheduler_state.timesteps, dtype = jnp.int32)[step]
376
+ tt = jnp.broadcast_to(t, latents.shape[0])
377
+ latents_input = self.scheduler.scale_model_input(scheduler_state, latents, t)
378
+ latents_input = jnp.concatenate([latents_input, mask, hint], axis = 1)
379
+ noise_pred = self.unet.apply(
380
+ { 'params': params['unet'] },
381
+ latents_input,
382
+ tt,
383
+ encoded_prompt
384
+ ).sample
385
+ noise_pred_uncond = self.unet.apply(
386
+ { 'params': params['unet'] },
387
+ latents_input,
388
+ tt,
389
+ encoded_neg_prompt
390
+ ).sample
391
+ noise_pred = noise_pred_uncond + cfg * (noise_pred - noise_pred_uncond)
392
+ latents, scheduler_state = self.scheduler.step(
393
+ scheduler_state,
394
+ noise_pred.astype(jnp.float32),
395
+ t,
396
+ latents
397
+ ).to_tuple()
398
+ return latents, scheduler_state
399
+
400
+ latents, _ = jax.lax.fori_loop(
401
+ 0, inference_steps,
402
+ sample_loop,
403
+ (latents, scheduler_state)
404
+ )
405
+ latents = 1 / self.vae.config.scaling_factor * latents
406
+ latents = einops.rearrange(latents, 'b c f h w -> (b f) c h w')
407
+ num_images = len(latents)
408
+ images_out = jnp.zeros(
409
+ (
410
+ num_images,
411
+ self.vae.config.out_channels,
412
+ height,
413
+ width
414
+ ),
415
+ dtype = self.dtype
416
+ )
417
+ def decode_loop(step, images_out):
418
+ # NOTE vae keeps channels last for encode, but rearranges to channels first for decode
419
+ im = self.vae.apply(
420
+ { 'params': params['vae'] },
421
+ jnp.expand_dims(latents[step], axis = 0),
422
+ method = self.vae.decode
423
+ ).sample
424
+ images_out = images_out.at[step].set(im[0])
425
+ return images_out
426
+ images_out = jax.lax.fori_loop(0, num_images, decode_loop, images_out)
427
+ images_out = ((images_out / 2 + 0.5) * 255).round().clip(0, 255).astype(jnp.uint8)
428
+ return images_out
429
+
430
+
431
+ @partial(
432
+ jax.pmap,
433
+ in_axes = ( # 0 -> split across batch dim, None -> duplicate
434
+ None, # 0 inference_class
435
+ 0, # 1 tokens
436
+ 0, # 2 neg_tokens
437
+ 0, # 3 hint
438
+ 0, # 4 mask
439
+ None, # 5 inference_steps
440
+ None, # 6 num_frames
441
+ None, # 7 height
442
+ None, # 8 width
443
+ None, # 9 cfg
444
+ 0, # 10 rng
445
+ 0, # 11 params
446
+ None, # 12 use_imagegen
447
+ ),
448
+ static_broadcasted_argnums = ( # trigger recompilation on change
449
+ 0, # inference_class
450
+ 5, # inference_steps
451
+ 6, # num_frames
452
+ 7, # height
453
+ 8, # width
454
+ 12, # use_imagegen
455
+ )
456
+ )
457
+ def _p_generate(
458
+ inference_class: InferenceUNetPseudo3D,
459
+ tokens,
460
+ neg_tokens,
461
+ hint,
462
+ mask,
463
+ inference_steps,
464
+ num_frames,
465
+ height,
466
+ width,
467
+ cfg,
468
+ rng,
469
+ params,
470
+ use_imagegen
471
+ ):
472
+ return inference_class._generate(
473
+ tokens,
474
+ neg_tokens,
475
+ hint,
476
+ mask,
477
+ inference_steps,
478
+ num_frames,
479
+ height,
480
+ width,
481
+ cfg,
482
+ rng,
483
+ params,
484
+ use_imagegen
485
+ )
486
+
makeavid_sd/makeavid_sd/torch_impl/__init__.py ADDED
File without changes
makeavid_sd/makeavid_sd/torch_impl/torch_attention_pseudo3d.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+
7
+ from einops import rearrange
8
+
9
+ from diffusers.models.attention_processor import Attention as CrossAttention
10
+ #from torch_cross_attention import CrossAttention
11
+
12
+
13
+ class TransformerPseudo3DModelOutput:
14
+ def __init__(self, sample: torch.FloatTensor) -> None:
15
+ self.sample = sample
16
+
17
+
18
+ class TransformerPseudo3DModel(nn.Module):
19
+ def __init__(self,
20
+ num_attention_heads: int = 16,
21
+ attention_head_dim: int = 88,
22
+ in_channels: Optional[int] = None,
23
+ num_layers: int = 1,
24
+ dropout: float = 0.0,
25
+ norm_num_groups: int = 32,
26
+ cross_attention_dim: Optional[int] = None,
27
+ attention_bias: bool = False
28
+ ) -> None:
29
+ super().__init__()
30
+ self.num_attention_heads = num_attention_heads
31
+ self.attention_head_dim = attention_head_dim
32
+ inner_dim = num_attention_heads * attention_head_dim
33
+
34
+ # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
35
+ # Define whether input is continuous or discrete depending on configuration
36
+ # its continuous
37
+
38
+ # 2. Define input layers
39
+ self.in_channels = in_channels
40
+
41
+ self.norm = torch.nn.GroupNorm(
42
+ num_groups = norm_num_groups,
43
+ num_channels = in_channels,
44
+ eps = 1e-6,
45
+ affine = True
46
+ )
47
+ self.proj_in = nn.Conv2d(
48
+ in_channels,
49
+ inner_dim,
50
+ kernel_size = 1,
51
+ stride = 1,
52
+ padding = 0
53
+ )
54
+
55
+ # 3. Define transformers blocks
56
+ self.transformer_blocks = nn.ModuleList(
57
+ [
58
+ BasicTransformerBlock(
59
+ inner_dim,
60
+ num_attention_heads,
61
+ attention_head_dim,
62
+ dropout = dropout,
63
+ cross_attention_dim = cross_attention_dim,
64
+ attention_bias = attention_bias,
65
+ )
66
+ for _ in range(num_layers)
67
+ ]
68
+ )
69
+
70
+ # 4. Define output layers
71
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size = 1, stride = 1, padding = 0)
72
+
73
+ def forward(self,
74
+ hidden_states: torch.Tensor,
75
+ encoder_hidden_states: Optional[torch.Tensor] = None,
76
+ timestep: torch.long = None
77
+ ) -> TransformerPseudo3DModelOutput:
78
+ """
79
+ Args:
80
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
81
+ When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
82
+ hidden_states
83
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
84
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
85
+ self-attention.
86
+ timestep ( `torch.long`, *optional*):
87
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
88
+ return_dict (`bool`, *optional*, defaults to `True`):
89
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
90
+
91
+ Returns:
92
+ [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
93
+ if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
94
+ tensor.
95
+ """
96
+ b, c, *_, h, w = hidden_states.shape
97
+ is_video = hidden_states.ndim == 5
98
+ f = None
99
+ if is_video:
100
+ b, c, f, h, w = hidden_states.shape
101
+ hidden_states = rearrange(hidden_states, 'b c f h w -> (b f) c h w')
102
+ #encoder_hidden_states = encoder_hidden_states.repeat_interleave(f, 0)
103
+
104
+ # 1. Input
105
+ batch, channel, height, weight = hidden_states.shape
106
+ residual = hidden_states
107
+ hidden_states = self.norm(hidden_states)
108
+ hidden_states = self.proj_in(hidden_states)
109
+ inner_dim = hidden_states.shape[1]
110
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
111
+
112
+ # 2. Blocks
113
+ for block in self.transformer_blocks:
114
+ hidden_states = block(
115
+ hidden_states,
116
+ context = encoder_hidden_states,
117
+ timestep = timestep,
118
+ frames_length = f,
119
+ height = height,
120
+ weight = weight
121
+ )
122
+
123
+ # 3. Output
124
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
125
+ hidden_states = self.proj_out(hidden_states)
126
+ output = hidden_states + residual
127
+
128
+ if is_video:
129
+ output = rearrange(output, '(b f) c h w -> b c f h w', b = b)
130
+
131
+ return TransformerPseudo3DModelOutput(sample = output)
132
+
133
+
134
+
135
+ class BasicTransformerBlock(nn.Module):
136
+ r"""
137
+ A basic Transformer block.
138
+
139
+ Parameters:
140
+ dim (`int`): The number of channels in the input and output.
141
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
142
+ attention_head_dim (`int`): The number of channels in each head.
143
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
144
+ cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention.
145
+ num_embeds_ada_norm (:
146
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
147
+ attention_bias (:
148
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
149
+ """
150
+
151
+ def __init__(self,
152
+ dim: int,
153
+ num_attention_heads: int,
154
+ attention_head_dim: int,
155
+ dropout: float = 0.0,
156
+ cross_attention_dim: Optional[int] = None,
157
+ attention_bias: bool = False,
158
+ ) -> None:
159
+ super().__init__()
160
+ self.attn1 = CrossAttention(
161
+ query_dim = dim,
162
+ heads = num_attention_heads,
163
+ dim_head = attention_head_dim,
164
+ dropout = dropout,
165
+ bias = attention_bias
166
+ ) # is a self-attention
167
+ self.ff = FeedForward(dim, dropout = dropout)
168
+ self.attn2 = CrossAttention(
169
+ query_dim = dim,
170
+ cross_attention_dim = cross_attention_dim,
171
+ heads = num_attention_heads,
172
+ dim_head = attention_head_dim,
173
+ dropout = dropout,
174
+ bias = attention_bias
175
+ ) # is self-attn if context is none
176
+ self.attn_temporal = CrossAttention(
177
+ query_dim = dim,
178
+ heads = num_attention_heads,
179
+ dim_head = attention_head_dim,
180
+ dropout = dropout,
181
+ bias = attention_bias
182
+ ) # is a self-attention
183
+
184
+ # layer norms
185
+ self.norm1 = nn.LayerNorm(dim)
186
+ self.norm2 = nn.LayerNorm(dim)
187
+ self.norm_temporal = nn.LayerNorm(dim)
188
+ self.norm3 = nn.LayerNorm(dim)
189
+
190
+ def forward(self,
191
+ hidden_states: torch.Tensor,
192
+ context: Optional[torch.Tensor] = None,
193
+ timestep: torch.int64 = None,
194
+ frames_length: Optional[int] = None,
195
+ height: Optional[int] = None,
196
+ weight: Optional[int] = None
197
+ ) -> torch.Tensor:
198
+ if context is not None and frames_length is not None:
199
+ context = context.repeat_interleave(frames_length, 0)
200
+ # 1. Self-Attention
201
+ norm_hidden_states = (
202
+ self.norm1(hidden_states)
203
+ )
204
+ hidden_states = self.attn1(norm_hidden_states) + hidden_states
205
+
206
+ # 2. Cross-Attention
207
+ norm_hidden_states = (
208
+ self.norm2(hidden_states)
209
+ )
210
+ hidden_states = self.attn2(
211
+ norm_hidden_states,
212
+ encoder_hidden_states = context
213
+ ) + hidden_states
214
+
215
+ # append temporal attention
216
+ if frames_length is not None:
217
+ hidden_states = rearrange(
218
+ hidden_states,
219
+ '(b f) (h w) c -> (b h w) f c',
220
+ f = frames_length,
221
+ h = height,
222
+ w = weight
223
+ )
224
+ norm_hidden_states = (
225
+ self.norm_temporal(hidden_states)
226
+ )
227
+ hidden_states = self.attn_temporal(norm_hidden_states) + hidden_states
228
+ hidden_states = rearrange(
229
+ hidden_states,
230
+ '(b h w) f c -> (b f) (h w) c',
231
+ f = frames_length,
232
+ h = height,
233
+ w = weight
234
+ )
235
+
236
+ # 3. Feed-forward
237
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
238
+ return hidden_states
239
+
240
+
241
+ class FeedForward(nn.Module):
242
+ r"""
243
+ A feed-forward layer.
244
+
245
+ Parameters:
246
+ dim (`int`): The number of channels in the input.
247
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
248
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
249
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
250
+ """
251
+
252
+ def __init__(self,
253
+ dim: int,
254
+ dim_out: Optional[int] = None,
255
+ mult: int = 4,
256
+ dropout: float = 0.0
257
+ ) -> None:
258
+ super().__init__()
259
+ inner_dim = int(dim * mult)
260
+ dim_out = dim_out if dim_out is not None else dim
261
+
262
+ geglu = GEGLU(dim, inner_dim)
263
+
264
+ self.net = nn.ModuleList([])
265
+ # project in
266
+ self.net.append(geglu)
267
+ # project dropout
268
+ self.net.append(nn.Dropout(dropout))
269
+ # project out
270
+ self.net.append(nn.Linear(inner_dim, dim_out))
271
+
272
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
273
+ for module in self.net:
274
+ hidden_states = module(hidden_states)
275
+ return hidden_states
276
+
277
+
278
+ # feedforward
279
+ class GEGLU(nn.Module):
280
+ r"""
281
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
282
+
283
+ Parameters:
284
+ dim_in (`int`): The number of channels in the input.
285
+ dim_out (`int`): The number of channels in the output.
286
+ """
287
+
288
+ def __init__(self, dim_in: int, dim_out: int) -> None:
289
+ super().__init__()
290
+ self.proj = nn.Linear(dim_in, dim_out * 2)
291
+
292
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
293
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim = -1)
294
+ return hidden_states * F.gelu(gate)
makeavid_sd/makeavid_sd/torch_impl/torch_cross_attention.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ class CrossAttention(nn.Module):
6
+ r"""
7
+ A cross attention layer.
8
+
9
+ Parameters:
10
+ query_dim (`int`): The number of channels in the query.
11
+ cross_attention_dim (`int`, *optional*):
12
+ The number of channels in the context. If not given, defaults to `query_dim`.
13
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
14
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
15
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
16
+ bias (`bool`, *optional*, defaults to False):
17
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
18
+ """
19
+
20
+ def __init__(self,
21
+ query_dim: int,
22
+ cross_attention_dim: Optional[int] = None,
23
+ heads: int = 8,
24
+ dim_head: int = 64,
25
+ dropout: float = 0.0,
26
+ bias: bool = False
27
+ ):
28
+ super().__init__()
29
+ inner_dim = dim_head * heads
30
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
31
+
32
+ self.scale = dim_head**-0.5
33
+ self.heads = heads
34
+ self.n_heads = heads
35
+ self.d_head = dim_head
36
+
37
+ self.to_q = nn.Linear(query_dim, inner_dim, bias = bias)
38
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias = bias)
39
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias = bias)
40
+
41
+ self.to_out = nn.ModuleList([])
42
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
43
+ self.to_out.append(nn.Dropout(dropout))
44
+ try:
45
+ # You can install flash attention by cloning their Github repo,
46
+ # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
47
+ # and then running `python setup.py install`
48
+ from flash_attn.flash_attention import FlashAttention
49
+ self.flash = FlashAttention()
50
+ # Set the scale for scaled dot-product attention.
51
+ self.flash.softmax_scale = self.scale
52
+ # Set to `None` if it's not installed
53
+ except ImportError:
54
+ self.flash = None
55
+
56
+ def reshape_heads_to_batch_dim(self, tensor):
57
+ batch_size, seq_len, dim = tensor.shape
58
+ head_size = self.heads
59
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
60
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
61
+ return tensor
62
+
63
+ def reshape_batch_dim_to_heads(self, tensor):
64
+ batch_size, seq_len, dim = tensor.shape
65
+ head_size = self.heads
66
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
67
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
68
+ return tensor
69
+
70
+ def forward(self,
71
+ hidden_states: torch.Tensor,
72
+ encoder_hidden_states: Optional[torch.Tensor] = None,
73
+ mask: Optional[torch.Tensor] = None
74
+ ) -> torch.Tensor:
75
+ batch_size, sequence_length, _ = hidden_states.shape
76
+ is_self = encoder_hidden_states is None
77
+ # attention, what we cannot get enough of
78
+ query = self.to_q(hidden_states)
79
+ has_cond = encoder_hidden_states is not None
80
+
81
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
82
+ key = self.to_k(encoder_hidden_states)
83
+ value = self.to_v(encoder_hidden_states)
84
+
85
+ dim = query.shape[-1]
86
+
87
+ if self.flash is not None and not has_cond and self.d_head <= 64:
88
+ hidden_states = self.flash_attention(query, key, value)
89
+ else:
90
+ hidden_states = self.normal_attention(query, key, value, is_self)
91
+
92
+ # linear proj
93
+ hidden_states = self.to_out[0](hidden_states)
94
+ # dropout
95
+ hidden_states = self.to_out[1](hidden_states)
96
+ return hidden_states
97
+
98
+ def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
99
+ """
100
+ #### Flash Attention
101
+ :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
102
+ :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
103
+ :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
104
+ """
105
+
106
+ # Get batch size and number of elements along sequence axis (`width * height`)
107
+ batch_size, seq_len, _ = q.shape
108
+
109
+ # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
110
+ # shape `[batch_size, seq_len, 3, n_heads * d_head]`
111
+ qkv = torch.stack((q, k, v), dim = 2)
112
+ # Split the heads
113
+ qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
114
+
115
+ # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
116
+ # fit this size.
117
+ if self.d_head <= 32:
118
+ pad = 32 - self.d_head
119
+ elif self.d_head <= 64:
120
+ pad = 64 - self.d_head
121
+ elif self.d_head <= 128:
122
+ pad = 128 - self.d_head
123
+ else:
124
+ raise ValueError(f'Head size ${self.d_head} too large for Flash Attention')
125
+
126
+ # Pad the heads
127
+ if pad:
128
+ qkv = torch.cat((qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim = -1)
129
+
130
+ # Compute attention
131
+ # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
132
+ # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
133
+ out, _ = self.flash(qkv)
134
+ # Truncate the extra head size
135
+ out = out[:, :, :, :self.d_head]
136
+ # Reshape to `[batch_size, seq_len, n_heads * d_head]`
137
+ out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
138
+
139
+ # Map to `[batch_size, height * width, d_model]` with a linear layer
140
+ return out
141
+
142
+ def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, is_self: bool):
143
+ """
144
+ #### Normal Attention
145
+
146
+ :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
147
+ :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
148
+ :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
149
+ """
150
+ # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
151
+ q = q.view(*q.shape[:2], self.n_heads, -1)
152
+ k = k.view(*k.shape[:2], self.n_heads, -1)
153
+ v = v.view(*v.shape[:2], self.n_heads, -1)
154
+
155
+ # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
156
+ attn = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale
157
+ # Compute softmax
158
+ # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
159
+ half = attn.shape[0] // 2
160
+ attn[half:] = attn[half:].softmax(dim = -1)
161
+ attn[:half] = attn[:half].softmax(dim = -1)
162
+
163
+ # Compute attention output
164
+ # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
165
+ out = torch.einsum('bhij,bjhd->bihd', attn, v)
166
+
167
+ # Reshape to `[batch_size, height * width, n_heads * d_head]`
168
+ out = out.reshape(*out.shape[:2], -1)
169
+
170
+ # Map to `[batch_size, height * width, d_model]` with a linear layer
171
+ return out
makeavid_sd/makeavid_sd/torch_impl/torch_embeddings.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+
5
+ def get_timestep_embedding(
6
+ timesteps: torch.Tensor,
7
+ embedding_dim: int,
8
+ flip_sin_to_cos: bool = False,
9
+ downscale_freq_shift: float = 1,
10
+ scale: float = 1,
11
+ max_period: int = 10000,
12
+ ) -> torch.Tensor:
13
+ """
14
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
15
+
16
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
17
+ These may be fractional.
18
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
19
+ embeddings. :return: an [N x dim] Tensor of positional embeddings.
20
+ """
21
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
22
+
23
+ half_dim = embedding_dim // 2
24
+ exponent = -math.log(max_period) * torch.arange(
25
+ start = 0,
26
+ end = half_dim,
27
+ dtype = torch.float32,
28
+ device = timesteps.device
29
+ )
30
+ exponent = exponent / (half_dim - downscale_freq_shift)
31
+
32
+ emb = torch.exp(exponent)
33
+ emb = timesteps[:, None].float() * emb[None, :]
34
+
35
+ # scale embeddings
36
+ emb = scale * emb
37
+
38
+ # concat sine and cosine embeddings
39
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim = -1)
40
+
41
+ # flip sine and cosine embeddings
42
+ if flip_sin_to_cos:
43
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim = -1)
44
+
45
+ # zero pad
46
+ if embedding_dim % 2 == 1:
47
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
48
+ return emb
49
+
50
+
51
+ class TimestepEmbedding(nn.Module):
52
+ def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
53
+ super().__init__()
54
+
55
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
56
+ self.act = None
57
+ if act_fn == "silu":
58
+ self.act = nn.SiLU()
59
+ elif act_fn == "mish":
60
+ self.act = nn.Mish()
61
+
62
+ if out_dim is not None:
63
+ time_embed_dim_out = out_dim
64
+ else:
65
+ time_embed_dim_out = time_embed_dim
66
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
67
+
68
+ def forward(self, sample):
69
+ sample = self.linear_1(sample)
70
+
71
+ if self.act is not None:
72
+ sample = self.act(sample)
73
+
74
+ sample = self.linear_2(sample)
75
+ return sample
76
+
77
+
78
+ class Timesteps(nn.Module):
79
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
80
+ super().__init__()
81
+ self.num_channels = num_channels
82
+ self.flip_sin_to_cos = flip_sin_to_cos
83
+ self.downscale_freq_shift = downscale_freq_shift
84
+
85
+ def forward(self, timesteps):
86
+ t_emb = get_timestep_embedding(
87
+ timesteps,
88
+ self.num_channels,
89
+ flip_sin_to_cos=self.flip_sin_to_cos,
90
+ downscale_freq_shift=self.downscale_freq_shift,
91
+ )
92
+ return t_emb
makeavid_sd/makeavid_sd/torch_impl/torch_resnet_pseudo3d.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+
6
+ class Pseudo3DConv(nn.Module):
7
+ def __init__(
8
+ self,
9
+ dim,
10
+ dim_out,
11
+ kernel_size,
12
+ **kwargs
13
+ ):
14
+ super().__init__()
15
+
16
+ self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size, **kwargs)
17
+ self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size, padding=kernel_size // 2)
18
+ self.temporal_conv = nn.Conv1d(dim_out, dim_out, 3, padding=1)
19
+
20
+ nn.init.dirac_(self.temporal_conv.weight.data) # initialized to be identity
21
+ nn.init.zeros_(self.temporal_conv.bias.data)
22
+
23
+ def forward(
24
+ self,
25
+ x,
26
+ convolve_across_time = True
27
+ ):
28
+ b, c, *_, h, w = x.shape
29
+
30
+ is_video = x.ndim == 5
31
+ convolve_across_time &= is_video
32
+
33
+ if is_video:
34
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
35
+
36
+ #with torch.no_grad():
37
+ # x = self.spatial_conv(x)
38
+ x = self.spatial_conv(x)
39
+
40
+ if is_video:
41
+ x = rearrange(x, '(b f) c h w -> b c f h w', b = b)
42
+ b, c, *_, h, w = x.shape
43
+
44
+ if not convolve_across_time:
45
+ return x
46
+
47
+ if is_video:
48
+ x = rearrange(x, 'b c f h w -> (b h w) c f')
49
+ x = self.temporal_conv(x)
50
+ x = rearrange(x, '(b h w) c f -> b c f h w', h = h, w = w)
51
+ return x
52
+
53
+ class Upsample2D(nn.Module):
54
+ """
55
+ An upsampling layer with an optional convolution.
56
+
57
+ Parameters:
58
+ channels: channels in the inputs and outputs.
59
+ use_conv: a bool determining if a convolution is applied.
60
+ use_conv_transpose:
61
+ out_channels:
62
+ """
63
+
64
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
65
+ super().__init__()
66
+ self.channels = channels
67
+ self.out_channels = out_channels or channels
68
+ self.use_conv = use_conv
69
+ self.use_conv_transpose = use_conv_transpose
70
+ self.name = name
71
+
72
+ conv = None
73
+ if use_conv_transpose:
74
+ conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
75
+ elif use_conv:
76
+ conv = Pseudo3DConv(self.channels, self.out_channels, 3, padding=1)
77
+
78
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
79
+ if name == "conv":
80
+ self.conv = conv
81
+ else:
82
+ self.Conv2d_0 = conv
83
+
84
+ def forward(self, hidden_states, output_size=None):
85
+ assert hidden_states.shape[1] == self.channels
86
+
87
+ if self.use_conv_transpose:
88
+ return self.conv(hidden_states)
89
+
90
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
91
+ # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
92
+ # https://github.com/pytorch/pytorch/issues/86679
93
+ dtype = hidden_states.dtype
94
+ if dtype == torch.bfloat16:
95
+ hidden_states = hidden_states.to(torch.float32)
96
+
97
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
98
+ if hidden_states.shape[0] >= 64:
99
+ hidden_states = hidden_states.contiguous()
100
+
101
+ b, c, *_, h, w = hidden_states.shape
102
+
103
+ is_video = hidden_states.ndim == 5
104
+
105
+ if is_video:
106
+ hidden_states = rearrange(hidden_states, 'b c f h w -> (b f) c h w')
107
+
108
+ # if `output_size` is passed we force the interpolation output
109
+ # size and do not make use of `scale_factor=2`
110
+ if output_size is None:
111
+ hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
112
+ else:
113
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
114
+
115
+ if is_video:
116
+ hidden_states = rearrange(hidden_states, '(b f) c h w -> b c f h w', b = b)
117
+
118
+ # If the input is bfloat16, we cast back to bfloat16
119
+ if dtype == torch.bfloat16:
120
+ hidden_states = hidden_states.to(dtype)
121
+
122
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
123
+ if self.use_conv:
124
+ if self.name == "conv":
125
+ hidden_states = self.conv(hidden_states)
126
+ else:
127
+ hidden_states = self.Conv2d_0(hidden_states)
128
+
129
+ return hidden_states
130
+
131
+
132
+ class Downsample2D(nn.Module):
133
+ """
134
+ A downsampling layer with an optional convolution.
135
+
136
+ Parameters:
137
+ channels: channels in the inputs and outputs.
138
+ use_conv: a bool determining if a convolution is applied.
139
+ out_channels:
140
+ padding:
141
+ """
142
+
143
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
144
+ super().__init__()
145
+ self.channels = channels
146
+ self.out_channels = out_channels or channels
147
+ self.use_conv = use_conv
148
+ self.padding = padding
149
+ stride = 2
150
+ self.name = name
151
+
152
+ if use_conv:
153
+ conv = Pseudo3DConv(self.channels, self.out_channels, 3, stride=stride, padding=padding)
154
+ else:
155
+ assert self.channels == self.out_channels
156
+ conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
157
+
158
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
159
+ if name == "conv":
160
+ self.Conv2d_0 = conv
161
+ self.conv = conv
162
+ elif name == "Conv2d_0":
163
+ self.conv = conv
164
+ else:
165
+ self.conv = conv
166
+
167
+ def forward(self, hidden_states):
168
+ assert hidden_states.shape[1] == self.channels
169
+ if self.use_conv and self.padding == 0:
170
+ pad = (0, 1, 0, 1)
171
+ hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
172
+
173
+ assert hidden_states.shape[1] == self.channels
174
+ if self.use_conv:
175
+ hidden_states = self.conv(hidden_states)
176
+ else:
177
+ b, c, *_, h, w = hidden_states.shape
178
+ is_video = hidden_states.ndim == 5
179
+ if is_video:
180
+ hidden_states = rearrange(hidden_states, 'b c f h w -> (b f) c h w')
181
+ hidden_states = self.conv(hidden_states)
182
+ if is_video:
183
+ hidden_states = rearrange(hidden_states, '(b f) c h w -> b c f h w', b = b)
184
+
185
+ return hidden_states
186
+
187
+
188
+ class ResnetBlockPseudo3D(nn.Module):
189
+ def __init__(
190
+ self,
191
+ *,
192
+ in_channels,
193
+ out_channels=None,
194
+ conv_shortcut=False,
195
+ dropout=0.0,
196
+ temb_channels=512,
197
+ groups=32,
198
+ groups_out=None,
199
+ pre_norm=True,
200
+ eps=1e-6,
201
+ time_embedding_norm="default",
202
+ kernel=None,
203
+ output_scale_factor=1.0,
204
+ use_in_shortcut=None,
205
+ up=False,
206
+ down=False,
207
+ ):
208
+ super().__init__()
209
+ self.pre_norm = pre_norm
210
+ self.pre_norm = True
211
+ self.in_channels = in_channels
212
+ out_channels = in_channels if out_channels is None else out_channels
213
+ self.out_channels = out_channels
214
+ self.use_conv_shortcut = conv_shortcut
215
+ self.time_embedding_norm = time_embedding_norm
216
+ self.up = up
217
+ self.down = down
218
+ self.output_scale_factor = output_scale_factor
219
+ print('OUTPUT_SCALE_FACTOR:', output_scale_factor)
220
+
221
+ if groups_out is None:
222
+ groups_out = groups
223
+
224
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
225
+
226
+ self.conv1 = Pseudo3DConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
227
+
228
+ if temb_channels is not None:
229
+ self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
230
+ else:
231
+ self.time_emb_proj = None
232
+
233
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
234
+ self.dropout = torch.nn.Dropout(dropout)
235
+ self.conv2 = Pseudo3DConv(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
236
+
237
+ self.nonlinearity = nn.SiLU()
238
+
239
+ self.upsample = self.downsample = None
240
+ if self.up:
241
+ self.upsample = Upsample2D(in_channels, use_conv=False)
242
+ elif self.down:
243
+ self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
244
+
245
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
246
+
247
+ self.conv_shortcut = None
248
+ if self.use_in_shortcut:
249
+ self.conv_shortcut = Pseudo3DConv(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
250
+
251
+ def forward(self, input_tensor, temb):
252
+ hidden_states = input_tensor
253
+
254
+ hidden_states = self.norm1(hidden_states)
255
+ hidden_states = self.nonlinearity(hidden_states)
256
+
257
+ if self.upsample is not None:
258
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
259
+ if hidden_states.shape[0] >= 64:
260
+ input_tensor = input_tensor.contiguous()
261
+ hidden_states = hidden_states.contiguous()
262
+ input_tensor = self.upsample(input_tensor)
263
+ hidden_states = self.upsample(hidden_states)
264
+ elif self.downsample is not None:
265
+ input_tensor = self.downsample(input_tensor)
266
+ hidden_states = self.downsample(hidden_states)
267
+
268
+ hidden_states = self.conv1(hidden_states)
269
+
270
+ if temb is not None:
271
+ b, c, *_, h, w = hidden_states.shape
272
+ is_video = hidden_states.ndim == 5
273
+ if is_video:
274
+ b, c, f, h, w = hidden_states.shape
275
+ hidden_states = rearrange(hidden_states, 'b c f h w -> (b f) c h w')
276
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
277
+ hidden_states = hidden_states + temb.repeat_interleave(f, 0)
278
+ hidden_states = rearrange(hidden_states, '(b f) c h w -> b c f h w', b=b)
279
+ else:
280
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
281
+ hidden_states = hidden_states + temb
282
+
283
+ hidden_states = self.norm2(hidden_states)
284
+ hidden_states = self.nonlinearity(hidden_states)
285
+
286
+ hidden_states = self.dropout(hidden_states)
287
+ hidden_states = self.conv2(hidden_states)
288
+
289
+ if self.conv_shortcut is not None:
290
+ input_tensor = self.conv_shortcut(input_tensor)
291
+
292
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
293
+
294
+ return output_tensor
295
+
makeavid_sd/makeavid_sd/torch_impl/torch_unet_pseudo3d_blocks.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union, Optional
2
+ import torch
3
+ from torch import nn
4
+
5
+ from torch_attention_pseudo3d import TransformerPseudo3DModel
6
+ from torch_resnet_pseudo3d import Downsample2D, ResnetBlockPseudo3D, Upsample2D
7
+
8
+
9
+ class UNetMidBlock2DCrossAttn(nn.Module):
10
+ def __init__(self,
11
+ in_channels: int,
12
+ temb_channels: int,
13
+ dropout: float = 0.0,
14
+ num_layers: int = 1,
15
+ resnet_eps: float = 1e-6,
16
+ resnet_time_scale_shift: str = "default",
17
+ resnet_act_fn: str = "swish",
18
+ resnet_groups: Optional[int] = 32,
19
+ resnet_pre_norm: bool = True,
20
+ attn_num_head_channels: int = 1,
21
+ attention_type: str = "default",
22
+ output_scale_factor: float =1.0,
23
+ cross_attention_dim: int = 1280,
24
+ **kwargs
25
+ ) -> None:
26
+ super().__init__()
27
+
28
+ self.attention_type = attention_type
29
+ self.attn_num_head_channels = attn_num_head_channels
30
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
31
+
32
+ # there is always at least one resnet
33
+ resnets = [
34
+ ResnetBlockPseudo3D(
35
+ in_channels = in_channels,
36
+ out_channels = in_channels,
37
+ temb_channels = temb_channels,
38
+ eps = resnet_eps,
39
+ groups = resnet_groups,
40
+ dropout = dropout,
41
+ time_embedding_norm = resnet_time_scale_shift,
42
+ #non_linearity = resnet_act_fn,
43
+ output_scale_factor = output_scale_factor,
44
+ pre_norm = resnet_pre_norm
45
+ )
46
+ ]
47
+ attentions = []
48
+
49
+ for _ in range(num_layers):
50
+ attentions.append(
51
+ TransformerPseudo3DModel(
52
+ in_channels = in_channels,
53
+ num_attention_heads = attn_num_head_channels,
54
+ attention_head_dim = in_channels // attn_num_head_channels,
55
+ num_layers = 1,
56
+ cross_attention_dim = cross_attention_dim,
57
+ norm_num_groups = resnet_groups
58
+ )
59
+ )
60
+ resnets.append(
61
+ ResnetBlockPseudo3D(
62
+ in_channels = in_channels,
63
+ out_channels = in_channels,
64
+ temb_channels = temb_channels,
65
+ eps = resnet_eps,
66
+ groups = resnet_groups,
67
+ dropout = dropout,
68
+ time_embedding_norm = resnet_time_scale_shift,
69
+ #non_linearity = resnet_act_fn,
70
+ output_scale_factor = output_scale_factor,
71
+ pre_norm = resnet_pre_norm
72
+ )
73
+ )
74
+
75
+ self.attentions = nn.ModuleList(attentions)
76
+ self.resnets = nn.ModuleList(resnets)
77
+
78
+ def forward(self, hidden_states, temb = None, encoder_hidden_states = None):
79
+ hidden_states = self.resnets[0](hidden_states, temb)
80
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
81
+ hidden_states = attn(hidden_states, encoder_hidden_states).sample
82
+ hidden_states = resnet(hidden_states, temb)
83
+
84
+ return hidden_states
85
+
86
+
87
+ class CrossAttnDownBlock2D(nn.Module):
88
+ def __init__(self,
89
+ in_channels: int,
90
+ out_channels: int,
91
+ temb_channels: int,
92
+ dropout: float = 0.0,
93
+ num_layers: int = 1,
94
+ resnet_eps: float = 1e-6,
95
+ resnet_time_scale_shift: str = "default",
96
+ resnet_act_fn: str = "swish",
97
+ resnet_groups: int = 32,
98
+ resnet_pre_norm: bool = True,
99
+ attn_num_head_channels: int = 1,
100
+ cross_attention_dim: int = 1280,
101
+ attention_type: str = "default",
102
+ output_scale_factor: float = 1.0,
103
+ downsample_padding: int = 1,
104
+ add_downsample: bool = True
105
+ ):
106
+ super().__init__()
107
+ resnets = []
108
+ attentions = []
109
+
110
+ self.attention_type = attention_type
111
+ self.attn_num_head_channels = attn_num_head_channels
112
+
113
+ for i in range(num_layers):
114
+ in_channels = in_channels if i == 0 else out_channels
115
+ resnets.append(
116
+ ResnetBlockPseudo3D(
117
+ in_channels = in_channels,
118
+ out_channels = out_channels,
119
+ temb_channels = temb_channels,
120
+ eps = resnet_eps,
121
+ groups = resnet_groups,
122
+ dropout = dropout,
123
+ time_embedding_norm = resnet_time_scale_shift,
124
+ #non_linearity = resnet_act_fn,
125
+ output_scale_factor = output_scale_factor,
126
+ pre_norm = resnet_pre_norm
127
+ )
128
+ )
129
+ attentions.append(
130
+ TransformerPseudo3DModel(
131
+ in_channels = out_channels,
132
+ num_attention_heads = attn_num_head_channels,
133
+ attention_head_dim = out_channels // attn_num_head_channels,
134
+ num_layers = 1,
135
+ cross_attention_dim = cross_attention_dim,
136
+ norm_num_groups = resnet_groups
137
+ )
138
+ )
139
+ self.attentions = nn.ModuleList(attentions)
140
+ self.resnets = nn.ModuleList(resnets)
141
+
142
+ if add_downsample:
143
+ self.downsamplers = nn.ModuleList(
144
+ [
145
+ Downsample2D(
146
+ out_channels,
147
+ use_conv = True,
148
+ out_channels = out_channels,
149
+ padding = downsample_padding,
150
+ name = "op"
151
+ )
152
+ ]
153
+ )
154
+ else:
155
+ self.downsamplers = None
156
+
157
+ def forward(self, hidden_states, temb = None, encoder_hidden_states = None):
158
+ output_states = ()
159
+
160
+ for resnet, attn in zip(self.resnets, self.attentions):
161
+ hidden_states = resnet(hidden_states, temb)
162
+ hidden_states = attn(hidden_states, encoder_hidden_states = encoder_hidden_states).sample
163
+
164
+ output_states += (hidden_states,)
165
+
166
+ if self.downsamplers is not None:
167
+ for downsampler in self.downsamplers:
168
+ hidden_states = downsampler(hidden_states)
169
+
170
+ output_states += (hidden_states,)
171
+
172
+ return hidden_states, output_states
173
+
174
+
175
+ class DownBlock2D(nn.Module):
176
+ def __init__(self,
177
+ in_channels: int,
178
+ out_channels: int,
179
+ temb_channels: int,
180
+ dropout: float = 0.0,
181
+ num_layers: int = 1,
182
+ resnet_eps: float = 1e-6,
183
+ resnet_time_scale_shift: str = "default",
184
+ resnet_act_fn: str = "swish",
185
+ resnet_groups: int = 32,
186
+ resnet_pre_norm: bool = True,
187
+ output_scale_factor: float = 1.0,
188
+ add_downsample: bool = True,
189
+ downsample_padding: int = 1
190
+ ) -> None:
191
+ super().__init__()
192
+ resnets = []
193
+
194
+ for i in range(num_layers):
195
+ in_channels = in_channels if i == 0 else out_channels
196
+ resnets.append(
197
+ ResnetBlockPseudo3D(
198
+ in_channels = in_channels,
199
+ out_channels = out_channels,
200
+ temb_channels = temb_channels,
201
+ eps = resnet_eps,
202
+ groups = resnet_groups,
203
+ dropout = dropout,
204
+ time_embedding_norm = resnet_time_scale_shift,
205
+ #non_linearity = resnet_act_fn,
206
+ output_scale_factor = output_scale_factor,
207
+ pre_norm = resnet_pre_norm
208
+ )
209
+ )
210
+
211
+ self.resnets = nn.ModuleList(resnets)
212
+
213
+ if add_downsample:
214
+ self.downsamplers = nn.ModuleList(
215
+ [
216
+ Downsample2D(
217
+ out_channels,
218
+ use_conv = True,
219
+ out_channels = out_channels,
220
+ padding = downsample_padding,
221
+ name = "op"
222
+ )
223
+ ]
224
+ )
225
+ else:
226
+ self.downsamplers = None
227
+
228
+
229
+ def forward(self, hidden_states, temb = None):
230
+ output_states = ()
231
+
232
+ for resnet in self.resnets:
233
+ hidden_states = resnet(hidden_states, temb)
234
+
235
+ output_states += (hidden_states,)
236
+
237
+ if self.downsamplers is not None:
238
+ for downsampler in self.downsamplers:
239
+ hidden_states = downsampler(hidden_states)
240
+
241
+ output_states += (hidden_states,)
242
+
243
+ return hidden_states, output_states
244
+
245
+
246
+ class CrossAttnUpBlock2D(nn.Module):
247
+ def __init__(self,
248
+ in_channels: int,
249
+ out_channels: int,
250
+ prev_output_channel: int,
251
+ temb_channels: int,
252
+ dropout: float = 0.0,
253
+ num_layers: int = 1,
254
+ resnet_eps: float = 1e-6,
255
+ resnet_time_scale_shift: str = "default",
256
+ resnet_act_fn: str = "swish",
257
+ resnet_groups: int = 32,
258
+ resnet_pre_norm: bool = True,
259
+ attn_num_head_channels: int = 1,
260
+ cross_attention_dim: int = 1280,
261
+ attention_type: str = "default",
262
+ output_scale_factor: float = 1.0,
263
+ add_upsample: bool = True
264
+ ) -> None:
265
+ super().__init__()
266
+ resnets = []
267
+ attentions = []
268
+
269
+ self.attention_type = attention_type
270
+ self.attn_num_head_channels = attn_num_head_channels
271
+
272
+ for i in range(num_layers):
273
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
274
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
275
+
276
+ resnets.append(
277
+ ResnetBlockPseudo3D(
278
+ in_channels = resnet_in_channels + res_skip_channels,
279
+ out_channels = out_channels,
280
+ temb_channels = temb_channels,
281
+ eps = resnet_eps,
282
+ groups = resnet_groups,
283
+ dropout = dropout,
284
+ time_embedding_norm = resnet_time_scale_shift,
285
+ #non_linearity = resnet_act_fn,
286
+ output_scale_factor = output_scale_factor,
287
+ pre_norm = resnet_pre_norm
288
+ )
289
+ )
290
+ attentions.append(
291
+ TransformerPseudo3DModel(
292
+ in_channels = out_channels,
293
+ num_attention_heads = attn_num_head_channels,
294
+ attention_head_dim = out_channels // attn_num_head_channels,
295
+ num_layers = 1,
296
+ cross_attention_dim = cross_attention_dim,
297
+ norm_num_groups = resnet_groups
298
+ )
299
+ )
300
+ self.attentions = nn.ModuleList(attentions)
301
+ self.resnets = nn.ModuleList(resnets)
302
+
303
+ if add_upsample:
304
+ self.upsamplers = nn.ModuleList([
305
+ Upsample2D(
306
+ out_channels,
307
+ use_conv = True,
308
+ out_channels = out_channels
309
+ )
310
+ ])
311
+ else:
312
+ self.upsamplers = None
313
+
314
+ def forward(self,
315
+ hidden_states,
316
+ res_hidden_states_tuple,
317
+ temb = None,
318
+ encoder_hidden_states = None,
319
+ upsample_size = None
320
+ ):
321
+ for resnet, attn in zip(self.resnets, self.attentions):
322
+ # pop res hidden states
323
+ res_hidden_states = res_hidden_states_tuple[-1]
324
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
325
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
326
+ hidden_states = resnet(hidden_states, temb)
327
+ hidden_states = attn(hidden_states, encoder_hidden_states = encoder_hidden_states).sample
328
+
329
+ if self.upsamplers is not None:
330
+ for upsampler in self.upsamplers:
331
+ hidden_states = upsampler(hidden_states, upsample_size)
332
+
333
+ return hidden_states
334
+
335
+
336
+ class UpBlock2D(nn.Module):
337
+ def __init__(self,
338
+ in_channels: int,
339
+ prev_output_channel: int,
340
+ out_channels: int,
341
+ temb_channels: int,
342
+ dropout: float = 0.0,
343
+ num_layers: int = 1,
344
+ resnet_eps: float = 1e-6,
345
+ resnet_time_scale_shift: str = "default",
346
+ resnet_act_fn: str = "swish",
347
+ resnet_groups: int = 32,
348
+ resnet_pre_norm: bool = True,
349
+ output_scale_factor: float = 1.0,
350
+ add_upsample: bool = True
351
+ ) -> None:
352
+ super().__init__()
353
+ resnets = []
354
+
355
+ for i in range(num_layers):
356
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
357
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
358
+
359
+ resnets.append(
360
+ ResnetBlockPseudo3D(
361
+ in_channels = resnet_in_channels + res_skip_channels,
362
+ out_channels = out_channels,
363
+ temb_channels = temb_channels,
364
+ eps = resnet_eps,
365
+ groups = resnet_groups,
366
+ dropout = dropout,
367
+ time_embedding_norm = resnet_time_scale_shift,
368
+ #non_linearity = resnet_act_fn,
369
+ output_scale_factor = output_scale_factor,
370
+ pre_norm = resnet_pre_norm
371
+ )
372
+ )
373
+
374
+ self.resnets = nn.ModuleList(resnets)
375
+
376
+ if add_upsample:
377
+ self.upsamplers = nn.ModuleList([
378
+ Upsample2D(
379
+ out_channels,
380
+ use_conv = True,
381
+ out_channels = out_channels
382
+ )
383
+ ])
384
+ else:
385
+ self.upsamplers = None
386
+
387
+
388
+ def forward(self, hidden_states, res_hidden_states_tuple, temb = None, upsample_size = None):
389
+ for resnet in self.resnets:
390
+ # pop res hidden states
391
+ res_hidden_states = res_hidden_states_tuple[-1]
392
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
393
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
394
+ hidden_states = resnet(hidden_states, temb)
395
+
396
+ if self.upsamplers is not None:
397
+ for upsampler in self.upsamplers:
398
+ hidden_states = upsampler(hidden_states, upsample_size)
399
+
400
+ return hidden_states
401
+
402
+
403
+ def get_down_block(
404
+ down_block_type: str,
405
+ num_layers: int,
406
+ in_channels: int,
407
+ out_channels: int,
408
+ temb_channels: int,
409
+ add_downsample: bool,
410
+ resnet_eps: float,
411
+ resnet_act_fn: str,
412
+ attn_num_head_channels: int,
413
+ resnet_groups: Optional[int] = None,
414
+ cross_attention_dim: Optional[int] = None,
415
+ downsample_padding: Optional[int] = None,
416
+ ) -> Union[DownBlock2D, CrossAttnDownBlock2D]:
417
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
418
+ if down_block_type == "DownBlock2D":
419
+ return DownBlock2D(
420
+ num_layers = num_layers,
421
+ in_channels = in_channels,
422
+ out_channels = out_channels,
423
+ temb_channels = temb_channels,
424
+ add_downsample = add_downsample,
425
+ resnet_eps = resnet_eps,
426
+ resnet_act_fn = resnet_act_fn,
427
+ resnet_groups = resnet_groups,
428
+ downsample_padding = downsample_padding
429
+ )
430
+ elif down_block_type == "CrossAttnDownBlock2D":
431
+ if cross_attention_dim is None:
432
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
433
+ return CrossAttnDownBlock2D(
434
+ num_layers = num_layers,
435
+ in_channels = in_channels,
436
+ out_channels = out_channels,
437
+ temb_channels = temb_channels,
438
+ add_downsample = add_downsample,
439
+ resnet_eps = resnet_eps,
440
+ resnet_act_fn = resnet_act_fn,
441
+ resnet_groups = resnet_groups,
442
+ downsample_padding = downsample_padding,
443
+ cross_attention_dim = cross_attention_dim,
444
+ attn_num_head_channels = attn_num_head_channels
445
+ )
446
+ raise ValueError(f"{down_block_type} does not exist.")
447
+
448
+
449
+ def get_up_block(
450
+ up_block_type: str,
451
+ num_layers,
452
+ in_channels,
453
+ out_channels,
454
+ prev_output_channel,
455
+ temb_channels,
456
+ add_upsample,
457
+ resnet_eps,
458
+ resnet_act_fn,
459
+ attn_num_head_channels,
460
+ resnet_groups = None,
461
+ cross_attention_dim = None,
462
+ ) -> Union[UpBlock2D, CrossAttnUpBlock2D]:
463
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
464
+ if up_block_type == "UpBlock2D":
465
+ return UpBlock2D(
466
+ num_layers = num_layers,
467
+ in_channels = in_channels,
468
+ out_channels = out_channels,
469
+ prev_output_channel = prev_output_channel,
470
+ temb_channels = temb_channels,
471
+ add_upsample = add_upsample,
472
+ resnet_eps = resnet_eps,
473
+ resnet_act_fn = resnet_act_fn,
474
+ resnet_groups = resnet_groups
475
+ )
476
+ elif up_block_type == "CrossAttnUpBlock2D":
477
+ if cross_attention_dim is None:
478
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
479
+ return CrossAttnUpBlock2D(
480
+ num_layers = num_layers,
481
+ in_channels = in_channels,
482
+ out_channels = out_channels,
483
+ prev_output_channel = prev_output_channel,
484
+ temb_channels = temb_channels,
485
+ add_upsample = add_upsample,
486
+ resnet_eps = resnet_eps,
487
+ resnet_act_fn = resnet_act_fn,
488
+ resnet_groups = resnet_groups,
489
+ cross_attention_dim = cross_attention_dim,
490
+ attn_num_head_channels = attn_num_head_channels
491
+ )
492
+ raise ValueError(f"{up_block_type} does not exist.")
493
+
makeavid_sd/makeavid_sd/torch_impl/torch_unet_pseudo3d_condition.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, Union
2
+
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn as nn
6
+
7
+ from torch_embeddings import TimestepEmbedding, Timesteps
8
+ from torch_unet_pseudo3d_blocks import (
9
+ UNetMidBlock2DCrossAttn,
10
+ get_down_block,
11
+ get_up_block,
12
+ )
13
+
14
+ from torch_resnet_pseudo3d import Pseudo3DConv
15
+
16
+ class UNetPseudo3DConditionOutput:
17
+ sample: torch.FloatTensor
18
+ def __init__(self, sample: torch.FloatTensor) -> None:
19
+ self.sample = sample
20
+
21
+
22
+ class UNetPseudo3DConditionModel(nn.Module):
23
+ def __init__(self,
24
+ sample_size: Optional[int] = None,
25
+ in_channels: int = 9,
26
+ out_channels: int = 4,
27
+ flip_sin_to_cos: bool = True,
28
+ freq_shift: int = 0,
29
+ down_block_types: Tuple[str] = (
30
+ "CrossAttnDownBlock2D",
31
+ "CrossAttnDownBlock2D",
32
+ "CrossAttnDownBlock2D",
33
+ "DownBlock2D",
34
+ ),
35
+ up_block_types: Tuple[str] = (
36
+ "UpBlock2D",
37
+ "CrossAttnUpBlock2D",
38
+ "CrossAttnUpBlock2D",
39
+ "CrossAttnUpBlock2D"
40
+ ),
41
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
42
+ layers_per_block: int = 2,
43
+ downsample_padding: int = 1,
44
+ mid_block_scale_factor: float = 1,
45
+ act_fn: str = "silu",
46
+ norm_num_groups: int = 32,
47
+ norm_eps: float = 1e-5,
48
+ cross_attention_dim: int = 768,
49
+ attention_head_dim: int = 8,
50
+ **kwargs
51
+ ) -> None:
52
+ super().__init__()
53
+ self.dtype = torch.float32
54
+ self.sample_size = sample_size
55
+ time_embed_dim = block_out_channels[0] * 4
56
+
57
+ # input
58
+ self.conv_in = Pseudo3DConv(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
59
+
60
+ # time
61
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
62
+ timestep_input_dim = block_out_channels[0]
63
+
64
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
65
+
66
+ self.down_blocks = nn.ModuleList([])
67
+ self.mid_block = None
68
+ self.up_blocks = nn.ModuleList([])
69
+
70
+ # down
71
+ output_channel = block_out_channels[0]
72
+ for i, down_block_type in enumerate(down_block_types):
73
+ input_channel = output_channel
74
+ output_channel = block_out_channels[i]
75
+ is_final_block = i == len(block_out_channels) - 1
76
+
77
+ down_block = get_down_block(
78
+ down_block_type,
79
+ num_layers = layers_per_block,
80
+ in_channels = input_channel,
81
+ out_channels = output_channel,
82
+ temb_channels = time_embed_dim,
83
+ add_downsample = not is_final_block,
84
+ resnet_eps = norm_eps,
85
+ resnet_act_fn = act_fn,
86
+ resnet_groups = norm_num_groups,
87
+ cross_attention_dim = cross_attention_dim,
88
+ attn_num_head_channels = attention_head_dim,
89
+ downsample_padding = downsample_padding
90
+ )
91
+ self.down_blocks.append(down_block)
92
+
93
+ # mid
94
+ self.mid_block = UNetMidBlock2DCrossAttn(
95
+ in_channels = block_out_channels[-1],
96
+ temb_channels = time_embed_dim,
97
+ resnet_eps = norm_eps,
98
+ resnet_act_fn = act_fn,
99
+ output_scale_factor = mid_block_scale_factor,
100
+ resnet_time_scale_shift = "default",
101
+ cross_attention_dim = cross_attention_dim,
102
+ attn_num_head_channels = attention_head_dim,
103
+ resnet_groups = norm_num_groups
104
+ )
105
+
106
+ # count how many layers upsample the images
107
+ self.num_upsamplers = 0
108
+
109
+ # up
110
+ reversed_block_out_channels = list(reversed(block_out_channels))
111
+ output_channel = reversed_block_out_channels[0]
112
+ for i, up_block_type in enumerate(up_block_types):
113
+ is_final_block = i == len(block_out_channels) - 1
114
+
115
+ prev_output_channel = output_channel
116
+ output_channel = reversed_block_out_channels[i]
117
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
118
+
119
+ # add upsample block for all BUT final layer
120
+ if not is_final_block:
121
+ add_upsample = True
122
+ self.num_upsamplers += 1
123
+ else:
124
+ add_upsample = False
125
+
126
+ up_block = get_up_block(
127
+ up_block_type,
128
+ num_layers = layers_per_block + 1,
129
+ in_channels = input_channel,
130
+ out_channels = output_channel,
131
+ prev_output_channel = prev_output_channel,
132
+ temb_channels = time_embed_dim,
133
+ add_upsample = add_upsample,
134
+ resnet_eps = norm_eps,
135
+ resnet_act_fn = act_fn,
136
+ resnet_groups = norm_num_groups,
137
+ cross_attention_dim = cross_attention_dim,
138
+ attn_num_head_channels = attention_head_dim
139
+ )
140
+ self.up_blocks.append(up_block)
141
+ prev_output_channel = output_channel
142
+
143
+ # out
144
+ self.conv_norm_out = nn.GroupNorm(
145
+ num_channels = block_out_channels[0],
146
+ num_groups = norm_num_groups,
147
+ eps = norm_eps
148
+ )
149
+ self.conv_act = nn.SiLU()
150
+ self.conv_out = Pseudo3DConv(block_out_channels[0], out_channels, 3, padding = 1)
151
+
152
+
153
+ def forward(
154
+ self,
155
+ sample: torch.FloatTensor,
156
+ timesteps: Union[torch.Tensor, float, int],
157
+ encoder_hidden_states: torch.Tensor
158
+ ) -> Union[UNetPseudo3DConditionOutput, Tuple]:
159
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
160
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
161
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
162
+ # on the fly if necessary.
163
+ default_overall_up_factor = 2**self.num_upsamplers
164
+
165
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
166
+ forward_upsample_size = False
167
+ upsample_size = None
168
+
169
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
170
+ forward_upsample_size = True
171
+
172
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
173
+ timesteps = timesteps.expand(sample.shape[0])
174
+
175
+ t_emb = self.time_proj(timesteps)
176
+
177
+ # timesteps does not contain any weights and will always return f32 tensors
178
+ # but time_embedding might actually be running in fp16. so we need to cast here.
179
+ # there might be better ways to encapsulate this.
180
+ t_emb = t_emb.to(dtype=self.dtype)
181
+ emb = self.time_embedding(t_emb)
182
+
183
+ # 2. pre-process
184
+ sample = self.conv_in(sample)
185
+
186
+ # 3. down
187
+ down_block_res_samples = (sample,)
188
+ for downsample_block in self.down_blocks:
189
+ if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
190
+ sample, res_samples = downsample_block(
191
+ hidden_states = sample,
192
+ temb = emb,
193
+ encoder_hidden_states = encoder_hidden_states,
194
+ )
195
+ else:
196
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
197
+
198
+ down_block_res_samples += res_samples
199
+
200
+ # 4. mid
201
+ sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
202
+
203
+ # 5. up
204
+ for i, upsample_block in enumerate(self.up_blocks):
205
+ is_final_block = i == len(self.up_blocks) - 1
206
+
207
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
208
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
209
+
210
+ # if we have not reached the final block and need to forward the
211
+ # upsample size, we do it here
212
+ if not is_final_block and forward_upsample_size:
213
+ upsample_size = down_block_res_samples[-1].shape[2:]
214
+
215
+ if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
216
+ sample = upsample_block(
217
+ hidden_states = sample,
218
+ temb = emb,
219
+ res_hidden_states_tuple = res_samples,
220
+ encoder_hidden_states = encoder_hidden_states,
221
+ upsample_size = upsample_size,
222
+ )
223
+ else:
224
+ sample = upsample_block(
225
+ hidden_states = sample,
226
+ temb = emb,
227
+ res_hidden_states_tuple = res_samples,
228
+ upsample_size = upsample_size
229
+ )
230
+ # 6. post-process
231
+ sample = self.conv_norm_out(sample)
232
+ sample = self.conv_act(sample)
233
+ sample = self.conv_out(sample)
234
+
235
+ return UNetPseudo3DConditionOutput(sample = sample)
makeavid_sd/requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch
2
+ torch_xla
makeavid_sd/setup.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import setup
2
+ setup(
3
+ name = 'makeavid_sd',
4
+ version = '0.1.0',
5
+ description = 'makeavid sd',
6
+ author = 'Lopho',
7
+ author_email = '[email protected]',
8
+ platforms = ['any'],
9
+ license = 'GNU Affero General Public License v3',
10
+ url = 'http://github.com/lopho/makeavid-sd-tpu'
11
+ )
makeavid_sd/trainer_xla.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ['PJRT_DEVICE'] = 'TPU'
3
+
4
+ from tqdm.auto import tqdm
5
+ import torch
6
+ from torch.utils.data import DataLoader
7
+ from torch_xla.core import xla_model
8
+ from diffusers import UNetPseudo3DConditionModel
9
+ from dataset import load_dataset
10
+
11
+
12
+ class TempoTrainerXLA:
13
+ def __init__(self,
14
+ pretrained: str = 'lxj616/make-a-stable-diffusion-video-timelapse',
15
+ lr: float = 1e-4,
16
+ dtype: torch.dtype = torch.float32,
17
+ ) -> None:
18
+ self.dtype = dtype
19
+ self.device: torch.device = xla_model.xla_device(0)
20
+ unet: UNetPseudo3DConditionModel = UNetPseudo3DConditionModel.from_pretrained(
21
+ pretrained,
22
+ subfolder = 'unet'
23
+ ).to(dtype = dtype, memory_format = torch.contiguous_format)
24
+ unfreeze_all: bool = False
25
+ unet = unet.train()
26
+ if not unfreeze_all:
27
+ unet.requires_grad_(False)
28
+ for name, param in unet.named_parameters():
29
+ if 'temporal_conv' in name:
30
+ param.requires_grad_(True)
31
+ for block in [*unet.down_blocks, unet.mid_block, *unet.up_blocks]:
32
+ if hasattr(block, 'attentions') and block.attentions is not None:
33
+ for attn_block in block.attentions:
34
+ for transformer_block in attn_block.transformer_blocks:
35
+ transformer_block.requires_grad_(False)
36
+ transformer_block.attn_temporal.requires_grad_(True)
37
+ transformer_block.norm_temporal.requires_grad_(True)
38
+ else:
39
+ unet.requires_grad_(True)
40
+ self.model: UNetPseudo3DConditionModel = unet.to(device = self.device)
41
+ #self.model = torch.compile(self.model, backend = 'aot_torchxla_trace_once')
42
+ self.params = lambda: filter(lambda p: p.requires_grad, self.model.parameters())
43
+ self.optim: torch.optim.Optimizer = torch.optim.AdamW(self.params(), lr = lr)
44
+ def lr_warmup(warmup_steps: int = 0):
45
+ def lambda_lr(step: int) -> float:
46
+ if step < warmup_steps:
47
+ return step / warmup_steps
48
+ else:
49
+ return 1.0
50
+ return lambda_lr
51
+ self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optim, lr_lambda = lr_warmup(warmup_steps = 60), last_epoch = -1)
52
+
53
+ @torch.no_grad()
54
+ def train(self, dataloader: DataLoader, epochs: int = 1, log_every: int = 1, save_every: int = 1000) -> None:
55
+ # 'latent_model_input'
56
+ # 'encoder_hidden_states'
57
+ # 'timesteps'
58
+ # 'noise'
59
+ global_step: int = 0
60
+ for epoch in range(epochs):
61
+ pbar = tqdm(dataloader, dynamic_ncols = True, smoothing = 0.01)
62
+ for b in pbar:
63
+ latent_model_input: torch.Tensor = b['latent_model_input'].to(device = self.device)
64
+ encoder_hidden_states: torch.Tensor = b['encoder_hidden_states'].to(device = self.device)
65
+ timesteps: torch.Tensor = b['timesteps'].to(device = self.device)
66
+ noise: torch.Tensor = b['noise'].to(device = self.device)
67
+ with torch.enable_grad():
68
+ self.optim.zero_grad(set_to_none = True)
69
+ y = self.model(latent_model_input, timesteps, encoder_hidden_states).sample
70
+ loss = torch.nn.functional.mse_loss(noise, y)
71
+ loss.backward()
72
+ self.optim.step()
73
+ self.scheduler.step()
74
+ xla_model.mark_step()
75
+ if global_step % log_every == 0:
76
+ pbar.set_postfix({ 'loss': loss.detach().item(), 'epoch': epoch })
77
+
78
+ def main():
79
+ pretrained: str = 'lxj616/make-a-stable-diffusion-video-timelapse'
80
+ dataset_path: str = './storage/dataset/tempofunk'
81
+ dtype: torch.dtype = torch.bfloat16
82
+ trainer = TempoTrainerXLA(
83
+ pretrained = pretrained,
84
+ lr = 1e-5,
85
+ dtype = dtype
86
+ )
87
+ dataloader: DataLoader = load_dataset(
88
+ dataset_path = dataset_path,
89
+ pretrained = pretrained,
90
+ batch_size = 1,
91
+ num_frames = 10,
92
+ num_workers = 1,
93
+ dtype = dtype
94
+ )
95
+ trainer.train(
96
+ dataloader = dataloader,
97
+ epochs = 1000,
98
+ log_every = 1,
99
+ save_every = 1000
100
+ )
101
+
102
+ if __name__ == '__main__':
103
+ main()
104
+