recursionaut
committed on
Commit
•
6ded986
Parent(s):
c1a2b2a
testing files upload (#7)
- testing files upload (f489a598b0d6a46d9a99c819210b220269d9b29b)
- .gitignore +32 -0
- LICENSE +399 -0
- MODELCARD.md +128 -0
- README.md +34 -120
- config.yaml +16 -0
- generate_reconstructions.ipynb +0 -0
- huggingface_mae.py +293 -0
- loss.py +59 -0
- mae_modules.py +273 -0
- mae_utils.py +70 -0
- masking.py +51 -0
- normalizer.py +7 -0
- pyproject.toml +34 -0
- sample/AA41_s1_1.jp2 +0 -0
- sample/AA41_s1_2.jp2 +0 -0
- sample/AA41_s1_3.jp2 +0 -0
- sample/AA41_s1_4.jp2 +0 -0
- sample/AA41_s1_5.jp2 +0 -0
- sample/AA41_s1_6.jp2 +0 -0
- test_huggingface_mae.py +32 -0
- vit.py +309 -0
- vit_encoder.py +61 -0
.gitignore
ADDED
@@ -0,0 +1,32 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# model artifacts
*.pickle
*.ckpt
*.safetensors
LICENSE
ADDED
@@ -0,0 +1,399 @@
Attribution-NonCommercial 4.0 International

=======================================================================

[Standard, unmodified text of the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC BY-NC 4.0): 399 lines, ending with "Creative Commons may be contacted at creativecommons.org."]
MODELCARD.md
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
---
library_name: transformers
tags: []
---

# Model Card for Phenom CA-MAE-S/16

Channel-agnostic image encoding model designed for microscopy image featurization.
The model uses a vision transformer backbone with channelwise cross-attention over patch tokens to create contextualized representations separately for each channel.


## Model Details

### Model Description

This model is a [channel-agnostic masked autoencoder](https://openaccess.thecvf.com/content/CVPR2024/html/Kraus_Masked_Autoencoders_for_Microscopy_are_Scalable_Learners_of_Cellular_Biology_CVPR_2024_paper.html) trained to reconstruct microscopy images over three datasets:
1. RxRx3
2. JUMP-CP overexpression
3. JUMP-CP gene-knockouts

- **Developed, funded, and shared by:** Recursion
- **Model type:** Vision transformer CA-MAE
- **Image modality:** Optimized for microscopy images from the CellPainting assay
- **License:**


### Model Sources

- **Repository:** [https://github.com/recursionpharma/maes_microscopy](https://github.com/recursionpharma/maes_microscopy)
- **Paper:** [Masked Autoencoders for Microscopy are Scalable Learners of Cellular Biology](https://openaccess.thecvf.com/content/CVPR2024/html/Kraus_Masked_Autoencoders_for_Microscopy_are_Scalable_Learners_of_Cellular_Biology_CVPR_2024_paper.html)


## Uses

NOTE: the model's embeddings tend to extract useful features only after standard batch-correction post-processing. **We recommend**, at a *minimum*, applying the standard `PCA-CenterScale` pattern after running inference over your images, or better yet Typical Variation Normalization; a sketch follows the steps below:

1. Fit a PCA kernel on all the *control images* (or all images, if there are no controls) from across all experimental batches (e.g. the plates of wells from your assay),
2. Transform all the embeddings with that PCA kernel,
3. For each experimental batch, fit a separate StandardScaler on the transformed embeddings of the controls from step 2, then transform the rest of the embeddings from that batch with that StandardScaler.

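The following is a minimal sketch of that recipe using scikit-learn; the variable names (`control_embeddings`, `embeddings_per_batch`) and the `n_components` value are illustrative assumptions, not part of the released code.

```python
# Sketch of the recommended PCA-CenterScale post-processing (assumed names, not repo code).
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler


def pca_center_scale(control_embeddings, embeddings_per_batch, n_components=128):
    """control_embeddings: (N_controls, D) embeddings of control wells pooled across all batches.
    embeddings_per_batch: dict of batch_id -> {"controls": (n_i, D), "all": (m_i, D)} arrays.
    Returns a dict of batch_id -> batch-corrected embeddings."""
    # Step 1: fit the PCA kernel on controls gathered from across all experimental batches.
    pca = PCA(n_components=n_components).fit(control_embeddings)

    corrected = {}
    for batch_id, arrays in embeddings_per_batch.items():
        # Step 2: transform every embedding with the shared PCA kernel.
        controls_pca = pca.transform(arrays["controls"])
        batch_pca = pca.transform(arrays["all"])
        # Step 3: center/scale the batch with a scaler fit only on that batch's controls.
        scaler = StandardScaler().fit(controls_pca)
        corrected[batch_id] = scaler.transform(batch_pca)
    return corrected
```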
### Direct Use

- Create biologically useful embeddings of microscopy images
- Create contextualized embeddings of each channel of a microscopy image (set `return_channelwise_embeddings=True`)
- Leverage the full MAE encoder + decoder to predict new channels / stains for images without all 6 CellPainting channels

### Downstream Use

- A determined ML expert could fine-tune the encoder for downstream tasks such as classification

### Out-of-Scope Use

- Unlikely to be especially performant on brightfield microscopy images
- Out-of-domain medical images, such as H&E (though it may still be a decent baseline)

## Bias, Risks, and Limitations

- The primary limitation is that the embeddings tend to be more useful at scale. For example, if you only have 1 plate of microscopy images, the embeddings might underperform compared to a bespoke supervised model.

## How to Get Started with the Model

You should be able to run the tests below, which demonstrate how to use the model at inference time.

```python
import pytest
import torch

from huggingface_mae import MAEModel

huggingface_phenombeta_model_dir = "."
# huggingface_modelpath = "recursionpharma/test-pb-model"


@pytest.fixture
def huggingface_model():
    # Make sure you have the model/config downloaded from https://huggingface.co/recursionpharma/test-pb-model to this directory
    # huggingface-cli download recursionpharma/test-pb-model --local-dir=.
    huggingface_model = MAEModel.from_pretrained(huggingface_phenombeta_model_dir)
    huggingface_model.eval()
    return huggingface_model


@pytest.mark.parametrize("C", [1, 4, 6, 11])
@pytest.mark.parametrize("return_channelwise_embeddings", [True, False])
def test_model_predict(huggingface_model, C, return_channelwise_embeddings):
    example_input_array = torch.randint(
        low=0,
        high=255,
        size=(2, C, 256, 256),
        dtype=torch.uint8,
        device=huggingface_model.device,
    )
    huggingface_model.return_channelwise_embeddings = return_channelwise_embeddings
    embeddings = huggingface_model.predict(example_input_array)
    expected_output_dim = 384 * C if return_channelwise_embeddings else 384
    assert embeddings.shape == (2, expected_output_dim)
```
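Outside of pytest, the same flow reduces to a few lines. This is a hedged sketch that assumes the checkpoint and config have already been downloaded to the current directory as in the test above; `my_images` is a stand-in for your own `uint8` tensor of shape `(batch, channels, 256, 256)`.

```python
# Hedged usage sketch (not from the repository): embed a batch of images.
import torch
from huggingface_mae import MAEModel

model = MAEModel.from_pretrained(".")  # assumes model.safetensors + config are in "."
model.eval()

my_images = torch.randint(0, 255, (8, 6, 256, 256), dtype=torch.uint8)  # stand-in data
with torch.no_grad():
    embeddings = model.predict(my_images)      # shape (8, 384)
    model.return_channelwise_embeddings = True
    per_channel = model.predict(my_images)     # shape (8, 384 * 6)
```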


## Training, evaluation and testing details

See the paper linked above for details on model training and evaluation. The primary hyperparameters are included in the repo linked above.


## Environmental Impact

- **Hardware Type:** Nvidia H100 Hopper nodes
- **Hours used:** 400
- **Cloud Provider:** private cloud
- **Carbon Emitted:** 138.24 kg CO2 (roughly the equivalent of one car driving from Toronto to Montreal)

**BibTeX:**

```TeX
@inproceedings{kraus2024masked,
  title={Masked Autoencoders for Microscopy are Scalable Learners of Cellular Biology},
  author={Kraus, Oren and Kenyon-Dean, Kian and Saberian, Saber and Fallah, Maryam and McLean, Peter and Leung, Jess and Sharma, Vasudev and Khan, Ayla and Balakrishnan, Jia and Celik, Safiye and others},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={11757--11768},
  year={2024}
}
```

## Model Card Contact

- Kian Kenyon-Dean: [email protected]
- Oren Kraus: [email protected]
- Or, email: [email protected]
README.md
CHANGED
@@ -1,128 +1,42 @@
- [128 lines removed: the previous README body, a model card for this checkpoint that is largely the same text now added as MODELCARD.md above; its test snippet pointed at models/phenom_beta_huggingface instead of the repository root.]
# Masked Autoencoders are Scalable Learners of Cellular Morphology
Official repo for Recursion's two recently accepted papers:
- Spotlight full-length paper at [CVPR 2024](https://cvpr.thecvf.com/Conferences/2024/AcceptedPapers) -- Masked Autoencoders for Microscopy are Scalable Learners of Cellular Biology
  - Paper: https://arxiv.org/abs/2404.10242
  - CVPR poster page with video: https://cvpr.thecvf.com/virtual/2024/poster/31565
- Spotlight workshop paper at [NeurIPS 2023 Generative AI & Biology workshop](https://openreview.net/group?id=NeurIPS.cc/2023/Workshop/GenBio)
  - Paper: https://arxiv.org/abs/2309.16064

![vit_diff_mask_ratios](https://github.com/recursionpharma/maes_microscopy/assets/109550980/c15f46b1-cdb9-41a7-a4af-bdc9684a971d)


## Provided code
See the repo for the ingredients required to define our MAEs. Users seeking to re-implement training will need to stitch together the Encoder and Decoder modules according to their use case.

Furthermore, the baseline Vision Transformer backbone used in this work can be built with the following code snippet using timm:
```
import timm.models.vision_transformer as vit

def vit_base_patch16_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_base_patch16_224(**default_kwargs)
```
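As a quick sanity check, the factory above can be exercised as follows. This is a hedged example, not code from the repo: it assumes the snippet above has been run and that a recent timm release providing `ParallelScalingBlock` is installed; the expected output width of 768 follows from ViT-B with `num_classes=0`.

```python
# Usage sketch for the vit_base_patch16_256 factory defined above.
import torch

model = vit_base_patch16_256()      # ViT-B/16 backbone for 6-channel 256x256 inputs
model.eval()

x = torch.randn(2, 6, 256, 256)     # dummy 6-channel, CellPainting-sized batch
with torch.no_grad():
    feats = model(x)                # pooled features from the backbone
print(feats.shape)                  # expected: torch.Size([2, 768])
```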

## Provided models
A publicly available model for research can be found via Nvidia's BioNeMo platform, which handles inference and auto-scaling: https://www.rxrx.ai/phenom

We have partnered with Nvidia to host a publicly available, smaller, and more flexible version of the MAE phenomics foundation model, called Phenom-Beta. Interested parties can access it directly through the Nvidia BioNeMo API:
- https://blogs.nvidia.com/blog/drug-discovery-bionemo-generative-ai/
- https://www.youtube.com/watch?v=Gch6bX1toB0
config.yaml
ADDED
@@ -0,0 +1,16 @@
# © Recursion Pharmaceuticals 2024
loss:
  _target_: torch.nn.MSELoss  # combine with fourier loss weighted at 0.01 mixing factor for best results
  reduction: none
optimizer:
  _target_: timm.optim.lion.Lion
  _partial_: true
  lr: &lr 1.0e-4  # YAML anchor, reused by the scheduler below; 1e-4 for <= ViT-B, and 3e-5 for ViT-L
  weight_decay: 0.05
  betas: [0.9, 0.95]
lr_scheduler:
  _target_: torch.optim.lr_scheduler.OneCycleLR
  _partial_: true
  max_lr: *lr  # alias to the optimizer learning rate anchored above
  pct_start: 0.1
  anneal_strategy: cos
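The `_target_` / `_partial_` keys follow the Hydra instantiation convention. Below is a minimal sketch of how such a config could be turned into live objects; it assumes Hydra and OmegaConf are installed, and the stand-in `nn.Linear` module and `total_steps` value are illustrative, not taken from the uploaded files.

```python
# Hedged sketch: instantiating the config above with Hydra utilities.
import torch.nn as nn
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")
model = nn.Linear(384, 384)  # stand-in module with parameters to optimize

loss_fn = instantiate(cfg.loss)                              # torch.nn.MSELoss(reduction="none")
optimizer = instantiate(cfg.optimizer)(model.parameters())   # _partial_: true -> callable factory
scheduler = instantiate(cfg.lr_scheduler)(optimizer, total_steps=10_000)
```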
generate_reconstructions.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
huggingface_mae.py
ADDED
@@ -0,0 +1,293 @@
from typing import Dict, Tuple, Union

import torch
import torch.nn as nn

from transformers import PretrainedConfig, PreTrainedModel

from loss import FourierLoss
from normalizer import Normalizer
from mae_modules import CAMAEDecoder, MAEDecoder, MAEEncoder
from mae_utils import flatten_images
from vit import (
    generate_2d_sincos_pos_embeddings,
    sincos_positional_encoding_vit,
    vit_small_patch16_256,
)

TensorDict = Dict[str, torch.Tensor]


class MAEConfig(PretrainedConfig):
    model_type = "MAE"

    def __init__(
        self,
        mask_ratio=0.75,
        encoder=None,
        decoder=None,
        loss=None,
        optimizer=None,
        input_norm=None,
        fourier_loss=None,
        fourier_loss_weight=0.0,
        lr_scheduler=None,
        use_MAE_weight_init=False,
        crop_size=-1,
        mask_fourier_loss=True,
        return_channelwise_embeddings=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.mask_ratio = mask_ratio
        self.encoder = encoder
        self.decoder = decoder
        self.loss = loss
        self.optimizer = optimizer
        self.input_norm = input_norm
        self.fourier_loss = fourier_loss
        self.fourier_loss_weight = fourier_loss_weight
        self.lr_scheduler = lr_scheduler
        self.use_MAE_weight_init = use_MAE_weight_init
        self.crop_size = crop_size
        self.mask_fourier_loss = mask_fourier_loss
        self.return_channelwise_embeddings = return_channelwise_embeddings


class MAEModel(PreTrainedModel):
    config_class = MAEConfig

    # Loss metrics
    TOTAL_LOSS = "loss"
    RECON_LOSS = "reconstruction_loss"
    FOURIER_LOSS = "fourier_loss"

    def __init__(self, config: MAEConfig):
        super().__init__(config)

        self.mask_ratio = config.mask_ratio

        # Could use Hydra to instantiate instead
        self.encoder = MAEEncoder(
            vit_backbone=sincos_positional_encoding_vit(
                vit_backbone=vit_small_patch16_256(global_pool="avg")
            ),
            max_in_chans=11,  # upper limit on number of input channels
            channel_agnostic=True,
        )
        self.decoder = CAMAEDecoder(
            depth=8,
            embed_dim=512,
            mlp_ratio=4,
            norm_layer=nn.LayerNorm,
            num_heads=16,
            num_modalities=6,
            qkv_bias=True,
            tokens_per_modality=256,
        )
        self.input_norm = torch.nn.Sequential(
            Normalizer(),
            nn.InstanceNorm2d(None, affine=False, track_running_stats=False),
        )

        self.fourier_loss_weight = config.fourier_loss_weight
        self.mask_fourier_loss = config.mask_fourier_loss
        self.return_channelwise_embeddings = config.return_channelwise_embeddings
        self.tokens_per_channel = 256  # hardcode the number of tokens per channel since we are patch16 crop 256

        # loss stuff
        self.loss = torch.nn.MSELoss(reduction="none")

        self.fourier_loss = FourierLoss(num_multimodal_modalities=6)
        if self.fourier_loss_weight > 0 and self.fourier_loss is None:
            raise ValueError(
                "FourierLoss weight is activated but no fourier_loss was defined in constructor"
            )
        elif self.fourier_loss_weight >= 1:
            raise ValueError(
                "FourierLoss weight is too large to do mixing factor, weight should be < 1"
            )

        self.patch_size = int(self.encoder.vit_backbone.patch_embed.patch_size[0])

        # projection layer between the encoder and decoder
        self.encoder_decoder_proj = nn.Linear(
            self.encoder.embed_dim, self.decoder.embed_dim, bias=True
        )

        self.decoder_pred = nn.Linear(
            self.decoder.embed_dim,
            self.patch_size**2
            * (1 if self.encoder.channel_agnostic else self.in_chans),
            bias=True,
        )  # linear layer from decoder embedding to input dims

        # overwrite decoder pos embeddings based on encoder params
        self.decoder.pos_embeddings = generate_2d_sincos_pos_embeddings(  # type: ignore[assignment]
            self.decoder.embed_dim,
            length=self.encoder.vit_backbone.patch_embed.grid_size[0],
            use_class_token=self.encoder.vit_backbone.cls_token is not None,
            num_modality=(
                self.decoder.num_modalities if self.encoder.channel_agnostic else 1
            ),
        )

        if config.use_MAE_weight_init:
            w = self.encoder.vit_backbone.patch_embed.proj.weight.data
            torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

            torch.nn.init.normal_(self.encoder.vit_backbone.cls_token, std=0.02)
            torch.nn.init.normal_(self.decoder.mask_token, std=0.02)

            self.apply(self._MAE_init_weights)

    def setup(self, stage: str) -> None:
        super().setup(stage)

    def _MAE_init_weights(self, m):
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @staticmethod
    def decode_to_reconstruction(
        encoder_latent: torch.Tensor,
        ind_restore: torch.Tensor,
        proj: torch.nn.Module,
        decoder: MAEDecoder | CAMAEDecoder,
        pred: torch.nn.Module,
    ) -> torch.Tensor:
        """Feed forward the encoder latent through the decoders necessary projections and transformations."""
        decoder_latent_projection = proj(
            encoder_latent
        )  # projection from encoder.embed_dim to decoder.embed_dim
        decoder_tokens = decoder.forward_masked(
            decoder_latent_projection, ind_restore
        )  # decoder.embed_dim output
        predicted_reconstruction = pred(
            decoder_tokens
        )  # linear projection to input dim
        return predicted_reconstruction[:, 1:, :]  # drop class token

    def forward(
        self, imgs: torch.Tensor, constant_noise: Union[torch.Tensor, None] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        imgs = self.input_norm(imgs)
        latent, mask, ind_restore = self.encoder.forward_masked(
            imgs, self.mask_ratio, constant_noise
        )  # encoder blocks
        reconstruction = self.decode_to_reconstruction(
            latent,
            ind_restore,
            self.encoder_decoder_proj,
            self.decoder,
            self.decoder_pred,
        )
        return latent, reconstruction, mask

    def compute_MAE_loss(
        self,
        reconstruction: torch.Tensor,
        img: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """Computes final loss and returns specific values of component losses for metric reporting."""
        loss_dict = {}
        img = self.input_norm(img)
        target_flattened = flatten_images(
            img,
            patch_size=self.patch_size,
            channel_agnostic=self.encoder.channel_agnostic,
        )

        loss: torch.Tensor = self.loss(
            reconstruction, target_flattened
        )  # should be with MSE or MAE (L1) with reduction='none'
        loss = loss.mean(
            dim=-1
        )  # average over embedding dim -> mean loss per patch (N,L)
        loss = (loss * mask).sum() / mask.sum()  # mean loss on masked patches only
        loss_dict[self.RECON_LOSS] = loss.item()

        # compute fourier loss
        if self.fourier_loss_weight > 0:
            floss: torch.Tensor = self.fourier_loss(reconstruction, target_flattened)
            if not self.mask_fourier_loss:
                floss = floss.mean()
            else:
                floss = floss.mean(dim=-1)
                floss = (floss * mask).sum() / mask.sum()

            loss_dict[self.FOURIER_LOSS] = floss.item()

        # here we use a mixing factor to keep the loss magnitude appropriate with fourier
        if self.fourier_loss_weight > 0:
            loss = (1 - self.fourier_loss_weight) * loss + (
                self.fourier_loss_weight * floss
            )
        return loss, loss_dict

    def training_step(self, batch: TensorDict, batch_idx: int) -> TensorDict:
        img = batch["pixels"]
        latent, reconstruction, mask = self(img.clone())
        full_loss, loss_dict = self.compute_MAE_loss(reconstruction, img.float(), mask)
        return {
            "loss": full_loss,
            **loss_dict,  # type: ignore[dict-item]
        }

    def validation_step(self, batch: TensorDict, batch_idx: int) -> TensorDict:
        return self.training_step(batch, batch_idx)

    def update_metrics(self, outputs: TensorDict, batch: TensorDict) -> None:
        self.metrics["lr"].update(value=self.lr_scheduler.get_last_lr())
        for key, value in outputs.items():
            if key.endswith("loss"):
                self.metrics[key].update(value)

    def on_validation_batch_end(  # type: ignore[override]
        self,
        outputs: TensorDict,
        batch: TensorDict,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx)

    def predict(self, imgs: torch.Tensor) -> torch.Tensor:
        imgs = self.input_norm(imgs)
        X = self.encoder.vit_backbone.forward_features(
            imgs
        )  # 3d tensor N x num_tokens x dim
        if self.return_channelwise_embeddings:
            N, _, d = X.shape
            num_channels = imgs.shape[1]
            X_reshaped = X[:, 1:, :].view(N, num_channels, self.tokens_per_channel, d)
            pooled_segments = X_reshaped.mean(
                dim=2
            )  # Resulting shape: (N, num_channels, d)
            latent = pooled_segments.view(N, num_channels * d).contiguous()
        else:
            latent = X[:, 1:, :].mean(dim=1)  # 1 + 256 * C tokens
        return latent

    def save_pretrained(self, save_directory: str, **kwargs):
        filename = kwargs.pop("filename", "model.safetensors")
        modelpath = f"{save_directory}/{filename}"
        self.config.save_pretrained(save_directory)
        torch.save({"state_dict": self.state_dict()}, modelpath)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        filename = kwargs.pop("filename", "model.safetensors")

        modelpath = f"{pretrained_model_name_or_path}/{filename}"
        config = MAEConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        state_dict = torch.load(modelpath, map_location="cpu")
        model = cls(config, *model_args, **kwargs)
        model.load_state_dict(state_dict["state_dict"])
        return model
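For orientation, the pieces above compose as follows for a single masked-autoencoding step. This is a hedged sketch, not a training recipe from the repo: it assumes the repository modules imported by huggingface_mae.py (vit.py, normalizer.py, etc.) are importable, uses a randomly initialized model via the default `MAEConfig`, and feeds dummy data.

```python
# Hedged sketch: one masked reconstruction + loss computation with the classes above.
import torch
from huggingface_mae import MAEConfig, MAEModel

model = MAEModel(MAEConfig())  # randomly initialized CA-MAE (ViT-S/16 encoder)
imgs = torch.randint(0, 255, (2, 6, 256, 256), dtype=torch.uint8).float()  # dummy batch

latent, reconstruction, mask = model(imgs)                    # masked encode + decode
loss, loss_dict = model.compute_MAE_loss(reconstruction, imgs, mask)
loss.backward()                                               # loss is averaged over masked patches only
print(loss_dict)                                              # e.g. {"reconstruction_loss": ...}
```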
loss.py
ADDED
@@ -0,0 +1,59 @@
# © Recursion Pharmaceuticals 2024
import torch
import torch.nn as nn


class FourierLoss(nn.Module):
    def __init__(
        self,
        use_l1_loss: bool = True,
        num_multimodal_modalities: int = 1,  # set to 1 for vanilla MAE, 6 for channel-agnostic MAE
    ) -> None:
        """
        Fourier transform loss is only sound when using L1 or L2 loss to compare the frequency domains
        between the images / their radial histograms.

        We will always set `reduction="none"` and enforce that the computation of any reductions from the
        output of this loss be managed by the model under question.
        """
        super().__init__()
        self.loss = (
            nn.L1Loss(reduction="none") if use_l1_loss else nn.MSELoss(reduction="none")
        )
        self.num_modalities = num_multimodal_modalities

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # input = reconstructed image, target = original image
        # flattened images from MAE are (B, H*W, C), so, here we convert to B x C x H x W (note we assume H == W)
        flattened_images = len(input.shape) == len(target.shape) == 3
        if flattened_images:
            B, H_W, C = input.shape
            H_W = H_W // self.num_modalities
            four_d_shape = (B, C * self.num_modalities, int(H_W**0.5), int(H_W**0.5))
            input = input.view(*four_d_shape)
            target = target.view(*four_d_shape)
        else:
            B, C, h, w = input.shape
            H_W = h * w

        if len(input.shape) != len(target.shape) != 4:
            raise ValueError(
                f"Invalid input shape: got {input.shape} and {target.shape}."
            )

        fft_reconstructed = torch.fft.fft2(input)
        fft_original = torch.fft.fft2(target)

        magnitude_reconstructed = torch.abs(fft_reconstructed)
        magnitude_original = torch.abs(fft_original)

        loss_tensor: torch.Tensor = self.loss(
            magnitude_reconstructed, magnitude_original
        )

        # (This version of the loss has no radial-histogram bins, so whenever the inputs
        # were flattened the output loss should be reshaped back to the (B, H*W, C) layout.)
        if flattened_images:
            loss_tensor = loss_tensor.reshape(B, H_W * self.num_modalities, C)

        return loss_tensor
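A small sketch of how `FourierLoss` behaves on MAE-style flattened patch tensors; the shapes mimic a 6-channel, 256-tokens-per-channel, 16x16-patch setup and are purely illustrative.

```python
# Hedged example: FourierLoss on flattened patch tensors.
import torch
from loss import FourierLoss

criterion = FourierLoss(num_multimodal_modalities=6)

# (B, H*W, C) layout as produced by the MAE flattening utilities:
# 6 modalities x 256 patches per modality, 16*16 pixels per patch.
reconstruction = torch.randn(2, 6 * 256, 16 * 16)
target = torch.randn(2, 6 * 256, 16 * 16)

per_element = criterion(reconstruction, target)  # reduction="none", same layout as the inputs
print(per_element.shape)                         # torch.Size([2, 1536, 256])
```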
mae_modules.py
ADDED
@@ -0,0 +1,273 @@
# © Recursion Pharmaceuticals 2024
from functools import partial
from typing import Tuple, Union

import torch
import torch.nn as nn
from timm.models.helpers import checkpoint_seq
from timm.models.vision_transformer import Block, Mlp, VisionTransformer

from masking import transformer_random_masking
from vit import channel_agnostic_vit

# If interested in training new MAEs, combine an encoder and decoder into a new module, and you should
# leverage the flattening and unflattening utilities as needed from mae_utils.py.
# Be sure to use an encoder-decoder Linear projection layer to match encoder dims with decoder dimensions.
# As described in the paper, images are self-standardized at the start.


class SelfStandardize(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.self_standardize = nn.LazyInstanceNorm2d(
            affine=False, track_running_stats=False
        )

    def forward(self, pixels: torch.Tensor) -> torch.Tensor:
        x = pixels.float() / 255.0
        return self.self_standardize(x)


class MAEEncoder(nn.Module):
    def __init__(
        self,
        vit_backbone: VisionTransformer,
        max_in_chans: int = 6,
        channel_agnostic: bool = False,
    ) -> None:
        super().__init__()
        if channel_agnostic:
            self.vit_backbone = channel_agnostic_vit(
                vit_backbone, max_in_chans=max_in_chans
            )
        else:
            self.vit_backbone = vit_backbone
        self.max_in_chans = max_in_chans
        self.channel_agnostic = channel_agnostic

    @property
    def embed_dim(self) -> int:
        return int(self.vit_backbone.embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.vit_backbone.forward_features(x)
        x = self.vit_backbone.forward_head(x)
        return x  # type: ignore[no-any-return]

    def forward_masked(
        self,
        x: torch.Tensor,
        mask_ratio: float,
        constant_noise: Union[torch.Tensor, None] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x = self.vit_backbone.patch_embed(x)
        x = self.vit_backbone._pos_embed(x)  # adds class token
        x_ = x[:, 1:, :]  # no class token
        x_, mask, ind_restore = transformer_random_masking(
            x_, mask_ratio, constant_noise
        )
        x = torch.cat([x[:, :1, :], x_], dim=1)  # add class token
        x = self.vit_backbone.norm_pre(x)

        if self.vit_backbone.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.vit_backbone.blocks, x)
        else:
            x = self.vit_backbone.blocks(x)
        x = self.vit_backbone.norm(x)
        return x, mask, ind_restore

80 |
+
class MAEDecoder(nn.Module):
|
81 |
+
def __init__(
|
82 |
+
self,
|
83 |
+
embed_dim: int = 512,
|
84 |
+
depth: int = 8,
|
85 |
+
num_heads: int = 16,
|
86 |
+
mlp_ratio: float = 4,
|
87 |
+
qkv_bias: bool = True,
|
88 |
+
norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), # type: ignore[assignment]
|
89 |
+
) -> None:
|
90 |
+
super().__init__()
|
91 |
+
self.embed_dim = embed_dim
|
92 |
+
self.pos_embeddings = None # to be overwritten by MAE class
|
93 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
94 |
+
self.blocks = nn.Sequential(
|
95 |
+
*[
|
96 |
+
Block(
|
97 |
+
embed_dim,
|
98 |
+
num_heads,
|
99 |
+
mlp_ratio,
|
100 |
+
qkv_bias=qkv_bias,
|
101 |
+
norm_layer=norm_layer,
|
102 |
+
)
|
103 |
+
for i in range(depth)
|
104 |
+
]
|
105 |
+
)
|
106 |
+
self.norm = norm_layer(embed_dim)
|
107 |
+
|
108 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
109 |
+
x = x + self.pos_embeddings
|
110 |
+
x = self.blocks(x)
|
111 |
+
x = self.norm(x)
|
112 |
+
return x # type: ignore[no-any-return]
|
113 |
+
|
114 |
+
def forward_masked(
|
115 |
+
self, x: torch.Tensor, ind_restore: torch.Tensor
|
116 |
+
) -> torch.Tensor:
|
117 |
+
mask_tokens = self.mask_token.repeat(
|
118 |
+
x.shape[0], ind_restore.shape[1] + 1 - x.shape[1], 1
|
119 |
+
)
|
120 |
+
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # remove class token
|
121 |
+
x_ = torch.gather(
|
122 |
+
x_, dim=1, index=ind_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
|
123 |
+
) # unshuffle
|
124 |
+
x = torch.cat([x[:, :1, :], x_], dim=1) # add class token
|
125 |
+
|
126 |
+
x = x + self.pos_embeddings
|
127 |
+
x = self.blocks(x)
|
128 |
+
x = self.norm(x)
|
129 |
+
return x # type: ignore[no-any-return]
|
130 |
+
|
131 |
+
|
132 |
+
class CrossAttention(nn.Module):
|
133 |
+
def __init__(
|
134 |
+
self, embed_dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0
|
135 |
+
):
|
136 |
+
super().__init__()
|
137 |
+
self.num_heads = num_heads
|
138 |
+
head_dim = embed_dim // num_heads
|
139 |
+
self.scale = head_dim**-0.5
|
140 |
+
|
141 |
+
self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
|
142 |
+
self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
|
143 |
+
|
144 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
145 |
+
self.proj = nn.Linear(embed_dim, embed_dim)
|
146 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
147 |
+
|
148 |
+
def forward(self, x, context):
|
149 |
+
B, N, C = x.shape
|
150 |
+
_, M, _ = context.shape
|
151 |
+
|
152 |
+
q = (
|
153 |
+
self.q(x)
|
154 |
+
.reshape(B, N, self.num_heads, C // self.num_heads)
|
155 |
+
.permute(0, 2, 1, 3)
|
156 |
+
)
|
157 |
+
kv = (
|
158 |
+
self.kv(context)
|
159 |
+
.reshape(B, M, 2, self.num_heads, C // self.num_heads)
|
160 |
+
.permute(2, 0, 3, 1, 4)
|
161 |
+
)
|
162 |
+
k, v = kv[0], kv[1]
|
163 |
+
|
164 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
165 |
+
attn = attn.softmax(dim=-1)
|
166 |
+
attn = self.attn_drop(attn)
|
167 |
+
|
168 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
169 |
+
x = self.proj(x)
|
170 |
+
x = self.proj_drop(x)
|
171 |
+
return x
|
172 |
+
|
173 |
+
|
174 |
+
class CAMAEDecoder(nn.Module):
|
175 |
+
def __init__(
|
176 |
+
self,
|
177 |
+
num_modalities: int = 6,
|
178 |
+
tokens_per_modality: int = 256,
|
179 |
+
embed_dim: int = 256,
|
180 |
+
depth: int = 2,
|
181 |
+
num_heads: int = 16,
|
182 |
+
mlp_ratio: float = 4,
|
183 |
+
qkv_bias: bool = True,
|
184 |
+
norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), # type: ignore[assignment]
|
185 |
+
) -> None:
|
186 |
+
super().__init__()
|
187 |
+
self.num_modalities = num_modalities
|
188 |
+
self.tokens_per_modality = tokens_per_modality
|
189 |
+
self.embed_dim = embed_dim
|
190 |
+
self.pos_embeddings = None # to be overwritten by MAE class
|
191 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
192 |
+
self.placeholder = nn.Parameter(
|
193 |
+
torch.zeros(1, 1, embed_dim), requires_grad=False
|
194 |
+
)
|
195 |
+
self.modality_tokens = nn.ParameterList(
|
196 |
+
[
|
197 |
+
nn.Parameter(torch.zeros(1, 1, self.embed_dim))
|
198 |
+
for modality in range(self.num_modalities)
|
199 |
+
]
|
200 |
+
)
|
201 |
+
|
202 |
+
self.cross_attention = CrossAttention(embed_dim=self.embed_dim)
|
203 |
+
self.mlp = Mlp(self.embed_dim, hidden_features=int(self.embed_dim * mlp_ratio))
|
204 |
+
|
205 |
+
self.decoders = nn.ModuleList(
|
206 |
+
[
|
207 |
+
nn.Sequential(
|
208 |
+
*[
|
209 |
+
Block(
|
210 |
+
embed_dim,
|
211 |
+
num_heads,
|
212 |
+
mlp_ratio,
|
213 |
+
qkv_bias=qkv_bias,
|
214 |
+
norm_layer=norm_layer,
|
215 |
+
)
|
216 |
+
for i in range(depth)
|
217 |
+
]
|
218 |
+
)
|
219 |
+
for modality in range(self.num_modalities)
|
220 |
+
]
|
221 |
+
)
|
222 |
+
# self.norm = norm_layer(embed_dim) # we decided to drop the last layer norm
|
223 |
+
self.context_norm = norm_layer(embed_dim)
|
224 |
+
self.query_norm = norm_layer(embed_dim)
|
225 |
+
self.out_norm = norm_layer(embed_dim)
|
226 |
+
|
227 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
228 |
+
x_m_s = []
|
229 |
+
|
230 |
+
modality_tokens_concat = torch.cat(
|
231 |
+
[
|
232 |
+
self.placeholder,
|
233 |
+
] # placeholder for class token
|
234 |
+
+ [
|
235 |
+
m_t.repeat(1, self.tokens_per_modality, 1)
|
236 |
+
for m_t in self.modality_tokens
|
237 |
+
],
|
238 |
+
dim=1,
|
239 |
+
)
|
240 |
+
|
241 |
+
x = (
|
242 |
+
x + self.pos_embeddings + modality_tokens_concat
|
243 |
+
) # add pos and tiled modality tokens
|
244 |
+
x_ = x[:, 1:, :] # no class token
|
245 |
+
for m, decoder in enumerate(
|
246 |
+
self.decoders
|
247 |
+
): # iterate through modalities and decoders
|
248 |
+
x_m = x_[
|
249 |
+
:, m * self.tokens_per_modality : (m + 1) * self.tokens_per_modality, :
|
250 |
+
]
|
251 |
+
x_m = self.cross_attention(self.query_norm(x_m), self.context_norm(x_))
|
252 |
+
x_m = x_m + self.mlp(self.out_norm(x_m))
|
253 |
+
x_m = decoder(x_m)
|
254 |
+
x_m_s.append(x_m)
|
255 |
+
x_m_s = torch.cat(x_m_s, dim=1) # concat all tokens
|
256 |
+
# x_m_s = self.norm(x_m_s) # we decided to drop the last layer norm
|
257 |
+
x_m_s = torch.cat([x[:, :1, :], x_m_s], dim=1) # add back class token
|
258 |
+
|
259 |
+
return x_m_s
|
260 |
+
|
261 |
+
def forward_masked(
|
262 |
+
self, x: torch.Tensor, ind_restore: torch.Tensor
|
263 |
+
) -> torch.Tensor:
|
264 |
+
mask_tokens = self.mask_token.repeat(
|
265 |
+
x.shape[0], ind_restore.shape[1] + 1 - x.shape[1], 1
|
266 |
+
)
|
267 |
+
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # remove class token
|
268 |
+
x_ = torch.gather(
|
269 |
+
x_, dim=1, index=ind_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
|
270 |
+
) # unshuffle
|
271 |
+
x = torch.cat([x[:, :1, :], x_], dim=1) # add class token
|
272 |
+
x = self.forward(x)
|
273 |
+
return x
|
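
The comment at the top of this file describes how to assemble a trainable MAE from these parts: an encoder, a decoder, a Linear projection between their widths, self-standardized inputs, and the flatten/unflatten helpers from mae_utils.py. The repository upload does not include such a training module, so the following is only a hedged sketch under those instructions; the class name `SimpleMAE`, the zero decoder positions, and the `pred_head` are our own placeholders, not the authors' training code.

# A minimal, hypothetical wrapper following the comment above; not the authors' training module.
import torch
import torch.nn as nn

from mae_modules import MAEDecoder, MAEEncoder, SelfStandardize
from mae_utils import flatten_images
from vit import sincos_positional_encoding_vit, vit_small_patch16_256


class SimpleMAE(nn.Module):
    def __init__(self, mask_ratio: float = 0.75, patch_size: int = 16) -> None:
        super().__init__()
        self.mask_ratio = mask_ratio
        self.patch_size = patch_size
        self.self_standardize = SelfStandardize()  # images are self-standardized at the start
        backbone = sincos_positional_encoding_vit(vit_small_patch16_256())
        self.encoder = MAEEncoder(backbone, channel_agnostic=False)
        self.decoder = MAEDecoder(embed_dim=512)
        # encoder-decoder Linear projection to match encoder dims with decoder dims
        self.encoder_to_decoder = nn.Linear(self.encoder.embed_dim, self.decoder.embed_dim)
        # the decoder expects pos_embeddings to be set by the wrapping module;
        # zeros keep this sketch runnable, a real model would use sin/cos positions here too
        num_tokens = backbone.pos_embed.shape[1]
        self.decoder.pos_embeddings = nn.Parameter(
            torch.zeros(1, num_tokens, self.decoder.embed_dim), requires_grad=False
        )
        self.pred_head = nn.Linear(self.decoder.embed_dim, patch_size**2 * 6)

    def forward(self, pixels: torch.Tensor):
        x = self.self_standardize(pixels)
        latent, mask, ind_restore = self.encoder.forward_masked(x, self.mask_ratio)
        latent = self.encoder_to_decoder(latent)
        decoded = self.decoder.forward_masked(latent, ind_restore)
        pred = self.pred_head(decoded[:, 1:, :])     # drop the class token
        target = flatten_images(x, self.patch_size)  # flattened patches to compare against
        return pred, target, mask

A reconstruction loss (for example MSE or the FourierLoss above) would then be computed per token on `pred` vs `target` and averaged over the positions where `mask` is 1.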
mae_utils.py
ADDED
@@ -0,0 +1,70 @@
# © Recursion Pharmaceuticals 2024
import math

import torch


def flatten_images(
    img: torch.Tensor, patch_size: int, channel_agnostic: bool = False
) -> torch.Tensor:
    """
    Flattens 2D images into tokens with the same pixel values

    Parameters
    ----------
    img : input image tensor (N, C, H, W)

    Returns
    -------
    flattened_img: flattened image tensor (N, L, patch_size**2 * C)
    """

    if (img.shape[2] != img.shape[3]) or (img.shape[2] % patch_size != 0):
        raise ValueError("image H must equal image W and be divisible by patch_size")
    in_chans = img.shape[1]

    h = w = int(img.shape[2] // patch_size)
    x = img.reshape(shape=(img.shape[0], in_chans, h, patch_size, w, patch_size))

    if channel_agnostic:
        x = torch.permute(x, (0, 1, 2, 4, 3, 5))  # NCHPWQ -> NCHWPQ
        x = x.reshape(shape=(img.shape[0], in_chans * h * w, int(patch_size**2)))
    else:
        x = torch.permute(x, (0, 2, 4, 3, 5, 1))  # NCHPWQ -> NHWPQC
        x = x.reshape(shape=(img.shape[0], h * w, int(patch_size**2 * in_chans)))
    return x


def unflatten_tokens(
    tokens: torch.Tensor,
    patch_size: int,
    num_modalities: int = 1,
    channel_agnostic: bool = False,
) -> torch.Tensor:
    """
    Unflattens tokens (N, L, patch_size**2 * C) into image tensor (N, C, H, W) with the pixel values

    Parameters
    ----------
    tokens : input token tensor (N, L, patch_size**2 * C)

    Returns
    -------
    img: image tensor (N, C, H, W)
    """
    if num_modalities > 1 and not channel_agnostic:
        raise ValueError("Multiple modalities requires channel agnostic unflattening.")

    h = w = int(math.sqrt(tokens.shape[1] // num_modalities))
    if h * w != (tokens.shape[1] // num_modalities):
        raise ValueError("sqrt of number of tokens not integer")

    if channel_agnostic:
        x = tokens.reshape(shape=(tokens.shape[0], -1, h, w, patch_size, patch_size))
        x = torch.permute(x, (0, 1, 2, 4, 3, 5))  # NCHWPQ -> NCHPWQ
    else:
        x = tokens.reshape(shape=(tokens.shape[0], h, w, patch_size, patch_size, -1))
        x = torch.permute(x, (0, 5, 1, 3, 2, 4))  # NHWPQC -> NCHPWQ
    img = x.reshape(shape=(x.shape[0], -1, h * patch_size, h * patch_size))

    return img
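
Because `flatten_images` and `unflatten_tokens` are pure reshapes and permutes, they invert each other exactly for square images. A quick sanity check, with sizes chosen arbitrarily for illustration:

# Illustrative round trip; the tensor sizes are our own choice.
import torch

from mae_utils import flatten_images, unflatten_tokens

img = torch.rand(2, 6, 64, 64)  # (N, C, H, W)
tokens = flatten_images(img, patch_size=16, channel_agnostic=True)
print(tokens.shape)  # (2, 6 * 4 * 4, 16 * 16) == (2, 96, 256)
restored = unflatten_tokens(tokens, patch_size=16, num_modalities=6, channel_agnostic=True)
assert torch.equal(img, restored)  # lossless round trip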
masking.py
ADDED
@@ -0,0 +1,51 @@
# © Recursion Pharmaceuticals 2024
from typing import Tuple, Union

import torch


def transformer_random_masking(
    x: torch.Tensor, mask_ratio: float, constant_noise: Union[torch.Tensor, None] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Random mask patches per sample

    Parameters
    ----------
    x : token tensor (N, L, D)
    mask_ratio : float - ratio of image to mask
    constant_noise : None, or a tensor of shape (N, L) to produce consistent masks

    Returns
    -------
    x_masked : sub-sampled version of x with shape (N, int(L * (1 - mask_ratio)), D)
    mask : binary mask indicating masked tokens (1 where masked) (N, L)
    ind_restore : locations of masked tokens, needed for decoder
    """

    N, L, D = x.shape  # batch, length, dim
    len_keep = int(L * (1 - mask_ratio))

    # use random noise to generate batch based random masks
    if constant_noise is not None:
        noise = constant_noise
    else:
        noise = torch.rand(N, L, device=x.device)

    shuffled_tokens = torch.argsort(noise, dim=1)  # shuffled index
    ind_restore = torch.argsort(shuffled_tokens, dim=1)  # unshuffled index

    # get masked input
    tokens_to_keep = shuffled_tokens[:, :len_keep]  # keep the first len_keep indices
    x_masked = torch.gather(
        x, dim=1, index=tokens_to_keep.unsqueeze(-1).repeat(1, 1, D)
    )

    # get binary mask used for loss masking: 0 is keep, 1 is remove
    mask = torch.ones([N, L], device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(
        mask, dim=1, index=ind_restore
    )  # unshuffle to get the binary mask

    return x_masked, mask, ind_restore
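
The returned shapes follow directly from the docstring above; for example, with a 0.75 mask ratio only a quarter of the tokens survive. Sizes here are arbitrary:

# Illustrative shapes only.
import torch

from masking import transformer_random_masking

tokens = torch.rand(2, 16, 8)  # (N, L, D)
x_masked, mask, ind_restore = transformer_random_masking(tokens, mask_ratio=0.75)
print(x_masked.shape)   # (2, 4, 8): int(L * (1 - mask_ratio)) tokens are kept
print(mask.sum(dim=1))  # tensor([12., 12.]): 1 marks a masked (removed) token
print(ind_restore.shape)  # (2, 16): used by the decoder to unshuffle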
normalizer.py
ADDED
@@ -0,0 +1,7 @@
import torch


class Normalizer(torch.nn.Module):
    def forward(self, pixels: torch.Tensor) -> torch.Tensor:
        pixels = pixels.float()
        return pixels / 255.0
pyproject.toml
ADDED
@@ -0,0 +1,34 @@
[build-system]
requires = ["setuptools >= 61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "maes_microscopy_project"
version = "0.1.0"
authors = [
    {name = "kian-kd", email = "[email protected]"},
    {name = "Laksh47", email = "[email protected]"},
]
requires-python = ">=3.10.4"

dependencies = [
    "huggingface-hub",
    "timm",
    "torch>=2.3",
    "torchmetrics",
    "torchvision",
    "tqdm",
    "transformers",
    "xformers",
    "zarr",
    "pytorch-lightning>=2.1",
    "matplotlib",
    "scikit-image",
    "ipykernel",
    "isort",
    "ruff",
    "pytest",
]

[tool.setuptools]
py-modules = []
sample/AA41_s1_1.jp2
ADDED
sample/AA41_s1_2.jp2
ADDED
sample/AA41_s1_3.jp2
ADDED
sample/AA41_s1_4.jp2
ADDED
sample/AA41_s1_5.jp2
ADDED
sample/AA41_s1_6.jp2
ADDED
test_huggingface_mae.py
ADDED
@@ -0,0 +1,32 @@
import pytest
import torch

from huggingface_mae import MAEModel

huggingface_phenombeta_model_dir = "."
# huggingface_modelpath = "recursionpharma/test-pb-model"


@pytest.fixture
def huggingface_model():
    # Make sure you have the model/config downloaded from https://huggingface.co/recursionpharma/test-pb-model to this directory
    # huggingface-cli download recursionpharma/test-pb-model --local-dir=.
    huggingface_model = MAEModel.from_pretrained(huggingface_phenombeta_model_dir)
    huggingface_model.eval()
    return huggingface_model


@pytest.mark.parametrize("C", [1, 4, 6, 11])
@pytest.mark.parametrize("return_channelwise_embeddings", [True, False])
def test_model_predict(huggingface_model, C, return_channelwise_embeddings):
    example_input_array = torch.randint(
        low=0,
        high=255,
        size=(2, C, 256, 256),
        dtype=torch.uint8,
        device=huggingface_model.device,
    )
    huggingface_model.return_channelwise_embeddings = return_channelwise_embeddings
    embeddings = huggingface_model.predict(example_input_array)
    expected_output_dim = 384 * C if return_channelwise_embeddings else 384
    assert embeddings.shape == (2, expected_output_dim)
vit.py
ADDED
@@ -0,0 +1,309 @@
# © Recursion Pharmaceuticals 2024
import timm.models.vision_transformer as vit
import torch


def generate_2d_sincos_pos_embeddings(
    embedding_dim: int,
    length: int,
    scale: float = 10000.0,
    use_class_token: bool = True,
    num_modality: int = 1,
) -> torch.nn.Parameter:
    """
    Generate 2-dimensional sin/cos positional embeddings

    Parameters
    ----------
    embedding_dim : int
        embedding dimension used in vit
    length : int
        number of tokens along height or width of image after patching (assuming square)
    scale : float
        scale for sin/cos functions
    use_class_token : bool
        True - a zero vector is prepended to be added to the class_token, False - no vector added
    num_modality: number of modalities. If 1, a single modality is assumed;
        otherwise the sin/cos encoding is tiled once per modality.

    Returns
    -------
    positional_encoding : torch.Tensor
        positional encoding to add to vit patch encodings
        [num_modality*length*length, embedding_dim] or [1+num_modality*length*length, embedding_dim]
        (w/ or w/o cls_token)
    """

    linear_positions = torch.arange(length, dtype=torch.float32)
    height_mesh, width_mesh = torch.meshgrid(
        linear_positions, linear_positions, indexing="ij"
    )
    positional_dim = embedding_dim // 4  # accommodate h and w x cos and sin embeddings
    positional_weights = (
        torch.arange(positional_dim, dtype=torch.float32) / positional_dim
    )
    positional_weights = 1.0 / (scale**positional_weights)

    height_weights = torch.outer(height_mesh.flatten(), positional_weights)
    width_weights = torch.outer(width_mesh.flatten(), positional_weights)

    positional_encoding = torch.cat(
        [
            torch.sin(height_weights),
            torch.cos(height_weights),
            torch.sin(width_weights),
            torch.cos(width_weights),
        ],
        dim=1,
    )[None, :, :]

    # repeat positional encoding for multiple channel modalities
    positional_encoding = positional_encoding.repeat(1, num_modality, 1)

    if use_class_token:
        class_token = torch.zeros([1, 1, embedding_dim], dtype=torch.float32)
        positional_encoding = torch.cat([class_token, positional_encoding], dim=1)

    positional_encoding = torch.nn.Parameter(positional_encoding, requires_grad=False)

    return positional_encoding


class ChannelAgnosticPatchEmbed(vit.PatchEmbed):  # type: ignore[misc]
    def __init__(
        self,
        img_size: int,
        patch_size: int,
        embed_dim: int,
        bias: bool = True,
    ) -> None:
        super().__init__(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=1,  # in_chans is used by self.proj, which we override anyway
            embed_dim=embed_dim,
            norm_layer=None,
            flatten=False,
            bias=bias,
        )
        # channel-agnostic MAE has a single projection for all chans
        self.proj = torch.nn.Conv2d(
            1, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_chans = x.shape[1]
        x = torch.stack(
            [self.proj(x[:, i : i + 1]) for i in range(in_chans)], dim=2
        )  # the single projection is applied to each channel
        x = x.flatten(2).transpose(1, 2)  # BCMHW -> BNC
        return x


class ChannelAgnosticViT(vit.VisionTransformer):  # type: ignore[misc]
    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        # rewrite of https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L586
        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))

        # TODO: upgrade timm to get access to register tokens
        # if self.vit_backbone.reg_token is not None:
        #     to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))

        # MAIN DIFFERENCE from timm: we DYNAMICALLY ADD POS EMBEDDINGS based on the shape of the inputs.
        # This supports having CA-MAEs actually be channel-agnostic at inference time.
        if self.no_embed_class:
            x = x + self.pos_embed[:, : x.shape[1]]
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
        else:
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
            x = x + self.pos_embed[:, : x.shape[1]]
        return self.pos_drop(x)  # type: ignore[no-any-return]


def channel_agnostic_vit(
    vit_backbone: vit.VisionTransformer, max_in_chans: int
) -> vit.VisionTransformer:
    # replace patch embedding with channel-agnostic version
    vit_backbone.patch_embed = ChannelAgnosticPatchEmbed(
        img_size=vit_backbone.patch_embed.img_size[0],
        patch_size=vit_backbone.patch_embed.patch_size[0],
        embed_dim=vit_backbone.embed_dim,
    )

    # replace positional embedding with channel-agnostic version
    vit_backbone.pos_embed = generate_2d_sincos_pos_embeddings(
        embedding_dim=vit_backbone.embed_dim,
        length=vit_backbone.patch_embed.grid_size[0],
        use_class_token=vit_backbone.cls_token is not None,
        num_modality=max_in_chans,
    )

    # change the class to be ChannelAgnostic so that it actually uses the new _pos_embed
    vit_backbone.__class__ = ChannelAgnosticViT
    return vit_backbone


def sincos_positional_encoding_vit(
    vit_backbone: vit.VisionTransformer, scale: float = 10000.0
) -> vit.VisionTransformer:
    """Attaches no-grad sin-cos positional embeddings to a pre-constructed ViT backbone model.

    Parameters
    ----------
    vit_backbone : timm.models.vision_transformer.VisionTransformer
        the constructed vision transformer from timm
    scale : float (default 10000.0)
        hyperparameter for sincos positional embeddings; recommend keeping at 10,000

    Returns
    -------
    timm.models.vision_transformer.VisionTransformer
        the same ViT but with fixed no-grad positional encodings to add to vit patch encodings
    """
    # length: number of tokens along height or width of image after patching (assuming square)
    length = (
        vit_backbone.patch_embed.img_size[0] // vit_backbone.patch_embed.patch_size[0]
    )
    pos_embeddings = generate_2d_sincos_pos_embeddings(
        vit_backbone.embed_dim,
        length=length,
        scale=scale,
        use_class_token=vit_backbone.cls_token is not None,
    )
    # note, if the model had weight_init == 'skip', this might get overwritten
    vit_backbone.pos_embed = pos_embeddings
    return vit_backbone


def vit_small_patch16_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_small_patch16_224(**default_kwargs)


def vit_small_patch32_512(**kwargs):
    default_kwargs = dict(
        img_size=512,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_small_patch32_384(**default_kwargs)


def vit_base_patch8_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_base_patch8_224(**default_kwargs)


def vit_base_patch16_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_base_patch16_224(**default_kwargs)


def vit_base_patch32_512(**kwargs):
    default_kwargs = dict(
        img_size=512,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_base_patch32_384(**default_kwargs)


def vit_large_patch8_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        patch_size=8,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        drop_path_rate=0.3,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.VisionTransformer(**default_kwargs)


def vit_large_patch16_256(**kwargs):
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.3,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    for k, v in kwargs.items():
        default_kwargs[k] = v
    return vit.vit_large_patch16_384(**default_kwargs)
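
Putting the pieces of this file together: `channel_agnostic_vit` swaps in the per-channel patch embedding and the modality-tiled sin/cos positions, after which the same weights accept any channel count up to `max_in_chans`. A hedged example (no pretrained weights; variable names and shapes are ours):

# Hypothetical wiring example.
import torch

from vit import channel_agnostic_vit, vit_small_patch16_256

backbone = vit_small_patch16_256()                 # 256x256 inputs, patch size 16
ca_backbone = channel_agnostic_vit(backbone, max_in_chans=6)

for C in (3, 6):                                   # channel counts up to max_in_chans
    imgs = torch.rand(2, C, 256, 256)
    feats = ca_backbone.forward_features(imgs)
    print(C, feats.shape)                          # (2, 1 + C * 256, 384) with the class token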
vit_encoder.py
ADDED
@@ -0,0 +1,61 @@
# © Recursion Pharmaceuticals 2024
from typing import Dict

import timm.models.vision_transformer as vit
import torch


def build_imagenet_baselines() -> Dict[str, torch.jit.ScriptModule]:
    """This returns the prepped imagenet encoders from timm, not bad for microscopy data."""
    vit_backbones = [
        _make_vit(vit.vit_small_patch16_384),
        _make_vit(vit.vit_base_patch16_384),
        _make_vit(vit.vit_base_patch8_224),
        _make_vit(vit.vit_large_patch16_384),
    ]
    model_names = [
        "vit_small_patch16_384",
        "vit_base_patch16_384",
        "vit_base_patch8_224",
        "vit_large_patch16_384",
    ]
    imagenet_encoders = list(map(_make_torchscripted_encoder, vit_backbones))
    return {name: model for name, model in zip(model_names, imagenet_encoders)}


def _make_torchscripted_encoder(vit_backbone) -> torch.jit.ScriptModule:
    dummy_input = torch.testing.make_tensor(
        (2, 6, 256, 256),
        low=0,
        high=255,
        dtype=torch.uint8,
        device=torch.device("cpu"),
    )
    encoder = torch.nn.Sequential(
        Normalizer(),
        torch.nn.LazyInstanceNorm2d(
            affine=False, track_running_stats=False
        ),  # this module performs self-standardization, very important
        vit_backbone,
    ).to(device="cpu")
    _ = encoder(dummy_input)  # get those lazy modules built
    return torch.jit.freeze(torch.jit.script(encoder.eval()))


def _make_vit(constructor):
    return constructor(
        pretrained=True,  # download imagenet weights
        img_size=256,  # 256x256 crops
        in_chans=6,  # we expect 6-channel microscopy images
        num_classes=0,
        fc_norm=None,
        class_token=True,
        global_pool="avg",  # minimal perf diff btwn "cls" and "avg"
    )


class Normalizer(torch.nn.Module):
    def forward(self, pixels: torch.Tensor) -> torch.Tensor:
        pixels = pixels.float()
        pixels /= 255.0
        return pixels
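
For completeness, a hedged usage example for the ImageNet baselines above; `build_imagenet_baselines` downloads the pretrained timm weights, so this assumes network access:

# Hypothetical usage; the pooled output dim depends on the backbone (e.g. 384 for vit_small_patch16_384).
import torch

from vit_encoder import build_imagenet_baselines

encoders = build_imagenet_baselines()  # dict: model name -> frozen TorchScript encoder
imgs = torch.randint(0, 255, (2, 6, 256, 256), dtype=torch.uint8)
with torch.no_grad():
    for name, encoder in encoders.items():
        print(name, encoder(imgs).shape)  # e.g. vit_small_patch16_384 torch.Size([2, 384])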