recursionaut committed
Commit
6ded986
1 Parent(s): c1a2b2a

testing files upload (#7)


- testing files upload (f489a598b0d6a46d9a99c819210b220269d9b29b)

.gitignore ADDED
@@ -0,0 +1,32 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # model artifacts
30
+ *.pickle
31
+ *.ckpt
32
+ *.safetensors
LICENSE ADDED
@@ -0,0 +1,399 @@
1
+ Attribution-NonCommercial 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More_considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial 4.0 International Public
58
+ License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial 4.0 International Public License ("Public
63
+ License"). To the extent this Public License may be interpreted as a
64
+ contract, You are granted the Licensed Rights in consideration of Your
65
+ acceptance of these terms and conditions, and the Licensor grants You
66
+ such rights in consideration of benefits the Licensor receives from
67
+ making the Licensed Material available under these terms and
68
+ conditions.
69
+
70
+ Section 1 -- Definitions.
71
+
72
+ a. Adapted Material means material subject to Copyright and Similar
73
+ Rights that is derived from or based upon the Licensed Material
74
+ and in which the Licensed Material is translated, altered,
75
+ arranged, transformed, or otherwise modified in a manner requiring
76
+ permission under the Copyright and Similar Rights held by the
77
+ Licensor. For purposes of this Public License, where the Licensed
78
+ Material is a musical work, performance, or sound recording,
79
+ Adapted Material is always produced where the Licensed Material is
80
+ synched in timed relation with a moving image.
81
+
82
+ b. Adapter's License means the license You apply to Your Copyright
83
+ and Similar Rights in Your contributions to Adapted Material in
84
+ accordance with the terms and conditions of this Public License.
85
+
86
+ c. Copyright and Similar Rights means copyright and/or similar rights
87
+ closely related to copyright including, without limitation,
88
+ performance, broadcast, sound recording, and Sui Generis Database
89
+ Rights, without regard to how the rights are labeled or
90
+ categorized. For purposes of this Public License, the rights
91
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
92
+ Rights.
93
+ d. Effective Technological Measures means those measures that, in the
94
+ absence of proper authority, may not be circumvented under laws
95
+ fulfilling obligations under Article 11 of the WIPO Copyright
96
+ Treaty adopted on December 20, 1996, and/or similar international
97
+ agreements.
98
+
99
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
100
+ any other exception or limitation to Copyright and Similar Rights
101
+ that applies to Your use of the Licensed Material.
102
+
103
+ f. Licensed Material means the artistic or literary work, database,
104
+ or other material to which the Licensor applied this Public
105
+ License.
106
+
107
+ g. Licensed Rights means the rights granted to You subject to the
108
+ terms and conditions of this Public License, which are limited to
109
+ all Copyright and Similar Rights that apply to Your use of the
110
+ Licensed Material and that the Licensor has authority to license.
111
+
112
+ h. Licensor means the individual(s) or entity(ies) granting rights
113
+ under this Public License.
114
+
115
+ i. NonCommercial means not primarily intended for or directed towards
116
+ commercial advantage or monetary compensation. For purposes of
117
+ this Public License, the exchange of the Licensed Material for
118
+ other material subject to Copyright and Similar Rights by digital
119
+ file-sharing or similar means is NonCommercial provided there is
120
+ no payment of monetary compensation in connection with the
121
+ exchange.
122
+
123
+ j. Share means to provide material to the public by any means or
124
+ process that requires permission under the Licensed Rights, such
125
+ as reproduction, public display, public performance, distribution,
126
+ dissemination, communication, or importation, and to make material
127
+ available to the public including in ways that members of the
128
+ public may access the material from a place and at a time
129
+ individually chosen by them.
130
+
131
+ k. Sui Generis Database Rights means rights other than copyright
132
+ resulting from Directive 96/9/EC of the European Parliament and of
133
+ the Council of 11 March 1996 on the legal protection of databases,
134
+ as amended and/or succeeded, as well as other essentially
135
+ equivalent rights anywhere in the world.
136
+
137
+ l. You means the individual or entity exercising the Licensed Rights
138
+ under this Public License. Your has a corresponding meaning.
139
+
140
+ Section 2 -- Scope.
141
+
142
+ a. License grant.
143
+
144
+ 1. Subject to the terms and conditions of this Public License,
145
+ the Licensor hereby grants You a worldwide, royalty-free,
146
+ non-sublicensable, non-exclusive, irrevocable license to
147
+ exercise the Licensed Rights in the Licensed Material to:
148
+
149
+ a. reproduce and Share the Licensed Material, in whole or
150
+ in part, for NonCommercial purposes only; and
151
+
152
+ b. produce, reproduce, and Share Adapted Material for
153
+ NonCommercial purposes only.
154
+
155
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
156
+ Exceptions and Limitations apply to Your use, this Public
157
+ License does not apply, and You do not need to comply with
158
+ its terms and conditions.
159
+
160
+ 3. Term. The term of this Public License is specified in Section
161
+ 6(a).
162
+
163
+ 4. Media and formats; technical modifications allowed. The
164
+ Licensor authorizes You to exercise the Licensed Rights in
165
+ all media and formats whether now known or hereafter created,
166
+ and to make technical modifications necessary to do so. The
167
+ Licensor waives and/or agrees not to assert any right or
168
+ authority to forbid You from making technical modifications
169
+ necessary to exercise the Licensed Rights, including
170
+ technical modifications necessary to circumvent Effective
171
+ Technological Measures. For purposes of this Public License,
172
+ simply making modifications authorized by this Section 2(a)
173
+ (4) never produces Adapted Material.
174
+
175
+ 5. Downstream recipients.
176
+
177
+ a. Offer from the Licensor -- Licensed Material. Every
178
+ recipient of the Licensed Material automatically
179
+ receives an offer from the Licensor to exercise the
180
+ Licensed Rights under the terms and conditions of this
181
+ Public License.
182
+
183
+ b. No downstream restrictions. You may not offer or impose
184
+ any additional or different terms or conditions on, or
185
+ apply any Effective Technological Measures to, the
186
+ Licensed Material if doing so restricts exercise of the
187
+ Licensed Rights by any recipient of the Licensed
188
+ Material.
189
+
190
+ 6. No endorsement. Nothing in this Public License constitutes or
191
+ may be construed as permission to assert or imply that You
192
+ are, or that Your use of the Licensed Material is, connected
193
+ with, or sponsored, endorsed, or granted official status by,
194
+ the Licensor or others designated to receive attribution as
195
+ provided in Section 3(a)(1)(A)(i).
196
+
197
+ b. Other rights.
198
+
199
+ 1. Moral rights, such as the right of integrity, are not
200
+ licensed under this Public License, nor are publicity,
201
+ privacy, and/or other similar personality rights; however, to
202
+ the extent possible, the Licensor waives and/or agrees not to
203
+ assert any such rights held by the Licensor to the limited
204
+ extent necessary to allow You to exercise the Licensed
205
+ Rights, but not otherwise.
206
+
207
+ 2. Patent and trademark rights are not licensed under this
208
+ Public License.
209
+
210
+ 3. To the extent possible, the Licensor waives any right to
211
+ collect royalties from You for the exercise of the Licensed
212
+ Rights, whether directly or through a collecting society
213
+ under any voluntary or waivable statutory or compulsory
214
+ licensing scheme. In all other cases the Licensor expressly
215
+ reserves any right to collect such royalties, including when
216
+ the Licensed Material is used other than for NonCommercial
217
+ purposes.
218
+
219
+ Section 3 -- License Conditions.
220
+
221
+ Your exercise of the Licensed Rights is expressly made subject to the
222
+ following conditions.
223
+
224
+ a. Attribution.
225
+
226
+ 1. If You Share the Licensed Material (including in modified
227
+ form), You must:
228
+
229
+ a. retain the following if it is supplied by the Licensor
230
+ with the Licensed Material:
231
+
232
+ i. identification of the creator(s) of the Licensed
233
+ Material and any others designated to receive
234
+ attribution, in any reasonable manner requested by
235
+ the Licensor (including by pseudonym if
236
+ designated);
237
+
238
+ ii. a copyright notice;
239
+
240
+ iii. a notice that refers to this Public License;
241
+
242
+ iv. a notice that refers to the disclaimer of
243
+ warranties;
244
+
245
+ v. a URI or hyperlink to the Licensed Material to the
246
+ extent reasonably practicable;
247
+
248
+ b. indicate if You modified the Licensed Material and
249
+ retain an indication of any previous modifications; and
250
+
251
+ c. indicate the Licensed Material is licensed under this
252
+ Public License, and include the text of, or the URI or
253
+ hyperlink to, this Public License.
254
+
255
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
256
+ reasonable manner based on the medium, means, and context in
257
+ which You Share the Licensed Material. For example, it may be
258
+ reasonable to satisfy the conditions by providing a URI or
259
+ hyperlink to a resource that includes the required
260
+ information.
261
+
262
+ 3. If requested by the Licensor, You must remove any of the
263
+ information required by Section 3(a)(1)(A) to the extent
264
+ reasonably practicable.
265
+
266
+ 4. If You Share Adapted Material You produce, the Adapter's
267
+ License You apply must not prevent recipients of the Adapted
268
+ Material from complying with this Public License.
269
+
270
+ Section 4 -- Sui Generis Database Rights.
271
+
272
+ Where the Licensed Rights include Sui Generis Database Rights that
273
+ apply to Your use of the Licensed Material:
274
+
275
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
276
+ to extract, reuse, reproduce, and Share all or a substantial
277
+ portion of the contents of the database for NonCommercial purposes
278
+ only;
279
+
280
+ b. if You include all or a substantial portion of the database
281
+ contents in a database in which You have Sui Generis Database
282
+ Rights, then the database in which You have Sui Generis Database
283
+ Rights (but not its individual contents) is Adapted Material; and
284
+
285
+ c. You must comply with the conditions in Section 3(a) if You Share
286
+ all or a substantial portion of the contents of the database.
287
+
288
+ For the avoidance of doubt, this Section 4 supplements and does not
289
+ replace Your obligations under this Public License where the Licensed
290
+ Rights include other Copyright and Similar Rights.
291
+
292
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
293
+
294
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
295
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
296
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
297
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
298
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
299
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
300
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
301
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
302
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
303
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
304
+
305
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
306
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
307
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
308
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
309
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
310
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
311
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
312
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
313
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
314
+
315
+ c. The disclaimer of warranties and limitation of liability provided
316
+ above shall be interpreted in a manner that, to the extent
317
+ possible, most closely approximates an absolute disclaimer and
318
+ waiver of all liability.
319
+
320
+ Section 6 -- Term and Termination.
321
+
322
+ a. This Public License applies for the term of the Copyright and
323
+ Similar Rights licensed here. However, if You fail to comply with
324
+ this Public License, then Your rights under this Public License
325
+ terminate automatically.
326
+
327
+ b. Where Your right to use the Licensed Material has terminated under
328
+ Section 6(a), it reinstates:
329
+
330
+ 1. automatically as of the date the violation is cured, provided
331
+ it is cured within 30 days of Your discovery of the
332
+ violation; or
333
+
334
+ 2. upon express reinstatement by the Licensor.
335
+
336
+ For the avoidance of doubt, this Section 6(b) does not affect any
337
+ right the Licensor may have to seek remedies for Your violations
338
+ of this Public License.
339
+
340
+ c. For the avoidance of doubt, the Licensor may also offer the
341
+ Licensed Material under separate terms or conditions or stop
342
+ distributing the Licensed Material at any time; however, doing so
343
+ will not terminate this Public License.
344
+
345
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
346
+ License.
347
+
348
+ Section 7 -- Other Terms and Conditions.
349
+
350
+ a. The Licensor shall not be bound by any additional or different
351
+ terms or conditions communicated by You unless expressly agreed.
352
+
353
+ b. Any arrangements, understandings, or agreements regarding the
354
+ Licensed Material not stated herein are separate from and
355
+ independent of the terms and conditions of this Public License.
356
+
357
+ Section 8 -- Interpretation.
358
+
359
+ a. For the avoidance of doubt, this Public License does not, and
360
+ shall not be interpreted to, reduce, limit, restrict, or impose
361
+ conditions on any use of the Licensed Material that could lawfully
362
+ be made without permission under this Public License.
363
+
364
+ b. To the extent possible, if any provision of this Public License is
365
+ deemed unenforceable, it shall be automatically reformed to the
366
+ minimum extent necessary to make it enforceable. If the provision
367
+ cannot be reformed, it shall be severed from this Public License
368
+ without affecting the enforceability of the remaining terms and
369
+ conditions.
370
+
371
+ c. No term or condition of this Public License will be waived and no
372
+ failure to comply consented to unless expressly agreed to by the
373
+ Licensor.
374
+
375
+ d. Nothing in this Public License constitutes or may be interpreted
376
+ as a limitation upon, or waiver of, any privileges and immunities
377
+ that apply to the Licensor or You, including from the legal
378
+ processes of any jurisdiction or authority.
379
+
380
+ =======================================================================
381
+
382
+ Creative Commons is not a party to its public
383
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
384
+ its public licenses to material it publishes and in those instances
385
+ will be considered the “Licensor.” The text of the Creative Commons
386
+ public licenses is dedicated to the public domain under the CC0 Public
387
+ Domain Dedication. Except for the limited purpose of indicating that
388
+ material is shared under a Creative Commons public license or as
389
+ otherwise permitted by the Creative Commons policies published at
390
+ creativecommons.org/policies, Creative Commons does not authorize the
391
+ use of the trademark "Creative Commons" or any other trademark or logo
392
+ of Creative Commons without its prior written consent including,
393
+ without limitation, in connection with any unauthorized modifications
394
+ to any of its public licenses or any other arrangements,
395
+ understandings, or agreements concerning use of licensed material. For
396
+ the avoidance of doubt, this paragraph does not form part of the
397
+ public licenses.
398
+
399
+ Creative Commons may be contacted at creativecommons.org.
MODELCARD.md ADDED
@@ -0,0 +1,128 @@
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Phenom CA-MAE-S/16
7
+
8
+ Channel-agnostic image encoding model designed for microscopy image featurization.
9
+ The model uses a vision transformer backbone with channelwise cross-attention over patch tokens to create contextualized representations separately for each channel.
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ This model is a [channel-agnostic masked autoencoder](https://openaccess.thecvf.com/content/CVPR2024/html/Kraus_Masked_Autoencoders_for_Microscopy_are_Scalable_Learners_of_Cellular_Biology_CVPR_2024_paper.html) trained to reconstruct microscopy images over three datasets:
17
+ 1. RxRx3
18
+ 2. JUMP-CP overexpression
19
+ 3. JUMP-CP gene-knockouts
20
+
21
+ - **Developed, funded, and shared by:** Recursion
22
+ - **Model type:** Vision transformer CA-MAE
23
+ - **Image modality:** Optimized for microscopy images from the CellPainting assay
24
+ - **License:**
25
+
26
+
27
+ ### Model Sources
28
+
29
+ - **Repository:** [https://github.com/recursionpharma/maes_microscopy](https://github.com/recursionpharma/maes_microscopy)
30
+ - **Paper:** [Masked Autoencoders for Microscopy are Scalable Learners of Cellular Biology](https://openaccess.thecvf.com/content/CVPR2024/html/Kraus_Masked_Autoencoders_for_Microscopy_are_Scalable_Learners_of_Cellular_Biology_CVPR_2024_paper.html)
31
+
32
+
33
+ ## Uses
34
+
35
+ NOTE: the model's embeddings tend to yield useful features only after standard batch-correction post-processing. **We recommend**, at a *minimum*, applying the standard `PCA-CenterScale` pattern, or better yet Typical Variation Normalization, after running inference over your images (a minimal sketch follows the steps below):
36
+
37
+ 1. Fit a PCA kernel on all the *control images* (or all images if no controls) from across all experimental batches (e.g. the plates of wells from your assay),
38
+ 2. Transform all the embeddings with that PCA kernel,
39
+ 3. For each experimental batch, fit a separate StandardScaler on the transformed embeddings of the controls from step 2, then transform the rest of the embeddings from that batch with that StandardScaler.
40
+
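For example, a minimal sketch of steps 1-3 using scikit-learn (the array names `embeddings`, `batch_ids`, `is_control` and the choice of `n_components` are illustrative, not part of this repo):

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler


def pca_center_scale(embeddings, batch_ids, is_control, n_components=256):
    """embeddings: (N, D) float array; batch_ids: (N,) batch labels; is_control: (N,) bool mask."""
    # 1. Fit the PCA kernel on control embeddings pooled across all batches
    #    (fall back to all embeddings if no controls are available).
    fit_set = embeddings[is_control] if is_control.any() else embeddings
    pca = PCA(n_components=n_components).fit(fit_set)
    # 2. Transform every embedding with that PCA kernel.
    transformed = pca.transform(embeddings)
    # 3. Per experimental batch: fit a StandardScaler on the transformed controls,
    #    then transform all embeddings from that batch with it.
    #    (Assumes each batch contains at least one control.)
    corrected = np.empty_like(transformed)
    for batch in np.unique(batch_ids):
        in_batch = batch_ids == batch
        scaler = StandardScaler().fit(transformed[in_batch & is_control])
        corrected[in_batch] = scaler.transform(transformed[in_batch])
    return corrected
```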
41
+ ### Direct Use
42
+
43
+ - Create biologically useful embeddings of microscopy images
44
+ - Create contextualized embeddings of each channel of a microscopy image (set `return_channelwise_embeddings=True`)
45
+ - Leverage the full MAE encoder + decoder to predict new channels / stains for images without all 6 CellPainting channels
46
+
47
+ ### Downstream Use
48
+
49
+ - A determined ML expert could fine-tune the encoder for downstream tasks such as classification
50
+
51
+ ### Out-of-Scope Use
52
+
53
+ - Unlikely to be especially performant on brightfield microscopy images
54
+ - Out-of-domain medical images, such as H&E-stained slides (though it might still serve as a decent baseline)
55
+
56
+ ## Bias, Risks, and Limitations
57
+
58
+ - Primary limitation is that the embeddings tend to be more useful at scale. For example, if you only have 1 plate of microscopy images, the embeddings might underperform compared to a supervised bespoke model.
59
+
60
+ ## How to Get Started with the Model
61
+
62
+ You should be able to successfully run the below tests, which demonstrate how to use the model at inference time.
63
+
64
+ ```python
65
+ import pytest
66
+ import torch
67
+
68
+ from huggingface_mae import MAEModel
69
+
70
+ huggingface_phenombeta_model_dir = "."
71
+ # huggingface_modelpath = "recursionpharma/test-pb-model"
72
+
73
+
74
+ @pytest.fixture
75
+ def huggingface_model():
76
+ # Make sure you have the model/config downloaded from https://huggingface.co/recursionpharma/test-pb-model to this directory
77
+ # huggingface-cli download recursionpharma/test-pb-model --local-dir=.
78
+ huggingface_model = MAEModel.from_pretrained(huggingface_phenombeta_model_dir)
79
+ huggingface_model.eval()
80
+ return huggingface_model
81
+
82
+
83
+ @pytest.mark.parametrize("C", [1, 4, 6, 11])
84
+ @pytest.mark.parametrize("return_channelwise_embeddings", [True, False])
85
+ def test_model_predict(huggingface_model, C, return_channelwise_embeddings):
86
+ example_input_array = torch.randint(
87
+ low=0,
88
+ high=255,
89
+ size=(2, C, 256, 256),
90
+ dtype=torch.uint8,
91
+ device=huggingface_model.device,
92
+ )
93
+ huggingface_model.return_channelwise_embeddings = return_channelwise_embeddings
94
+ embeddings = huggingface_model.predict(example_input_array)
95
+ expected_output_dim = 384 * C if return_channelwise_embeddings else 384
96
+ assert embeddings.shape == (2, expected_output_dim)
97
+ ```
98
+
99
+
100
+ ## Training, evaluation and testing details
101
+
102
+ See the paper linked above for details on model training and evaluation. The primary hyperparameters are included in the repo linked above.
103
+
104
+
105
+ ## Environmental Impact
106
+
107
+ - **Hardware Type:** Nvidia H100 Hopper nodes
108
+ - **Hours used:** 400
109
+ - **Cloud Provider:** private cloud
110
+ - **Carbon Emitted:** 138.24 kg CO2 (roughly equivalent to one car driving from Toronto to Montreal)
111
+
112
+ **BibTeX:**
113
+
114
+ ```TeX
115
+ @inproceedings{kraus2024masked,
116
+ title={Masked Autoencoders for Microscopy are Scalable Learners of Cellular Biology},
117
+ author={Kraus, Oren and Kenyon-Dean, Kian and Saberian, Saber and Fallah, Maryam and McLean, Peter and Leung, Jess and Sharma, Vasudev and Khan, Ayla and Balakrishnan, Jia and Celik, Safiye and others},
118
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
119
+ pages={11757--11768},
120
+ year={2024}
121
+ }
122
+ ```
123
+
124
+ ## Model Card Contact
125
+
126
+ - Kian Kenyon-Dean: [email protected]
127
+ - Oren Kraus: [email protected]
128
+ - Or, email: [email protected]
README.md CHANGED
@@ -1,128 +1,42 @@
1
- ---
2
- library_name: transformers
3
- tags: []
4
- ---
 
 
 
5
 
6
- # Model Card for Phenom CA-MAE-S/16
7
 
8
- Channel-agnostic image encoding model designed for microscopy image featurization.
9
- The model uses a vision transformer backbone with channelwise cross-attention over patch tokens to create contextualized representations separately for each channel.
10
 
 
 
11
 
12
- ## Model Details
13
-
14
- ### Model Description
15
-
16
- This model is a [channel-agnostic masked autoencoder](https://openaccess.thecvf.com/content/CVPR2024/html/Kraus_Masked_Autoencoders_for_Microscopy_are_Scalable_Learners_of_Cellular_Biology_CVPR_2024_paper.html) trained to reconstruct microscopy images over three datasets:
17
- 1. RxRx3
18
- 2. JUMP-CP overexpression
19
- 3. JUMP-CP gene-knockouts
20
-
21
- - **Developed, funded, and shared by:** Recursion
22
- - **Model type:** Vision transformer CA-MAE
23
- - **Image modality:** Optimized for microscopy images from the CellPainting assay
24
- - **License:**
25
-
26
-
27
- ### Model Sources
28
-
29
- - **Repository:** [https://github.com/recursionpharma/maes_microscopy](https://github.com/recursionpharma/maes_microscopy)
30
- - **Paper:** [Masked Autoencoders for Microscopy are Scalable Learners of Cellular Biology](https://openaccess.thecvf.com/content/CVPR2024/html/Kraus_Masked_Autoencoders_for_Microscopy_are_Scalable_Learners_of_Cellular_Biology_CVPR_2024_paper.html)
31
-
32
-
33
- ## Uses
34
-
35
- NOTE: model embeddings tend to extract features only after using standard batch correction post-processing techniques. **We recommend**, at a *minimum*, after inferencing the model over your images, to do the standard `PCA-CenterScale` pattern or better yet Typical Variation Normalization:
36
-
37
- 1. Fit a PCA kernel on all the *control images* (or all images if no controls) from across all experimental batches (e.g. the plates of wells from your assay),
38
- 2. Transform all the embeddings with that PCA kernel,
39
- 3. For each experimental batch, fit a separate StandardScaler on the transformed embeddings of the controls from step 2, then transform the rest of the embeddings from that batch with that StandardScaler.
40
-
41
- ### Direct Use
42
-
43
- - Create biologically useful embeddings of microscopy images
44
- - Create contextualized embeddings of each channel of a microscopy image (set `return_channelwise_embeddings=True`)
45
- - Leverage the full MAE encoder + decoder to predict new channels / stains for images without all 6 CellPainting channels
46
-
47
- ### Downstream Use
48
-
49
- - A determined ML expert could fine-tune the encoder for downstream tasks such as classification
50
-
51
- ### Out-of-Scope Use
52
-
53
- - Unlikely to be especially performant on brightfield microscopy images
54
- - Out-of-domain medical images, such as H&E (maybe it would be a decent baseline though)
55
-
56
- ## Bias, Risks, and Limitations
57
-
58
- - Primary limitation is that the embeddings tend to be more useful at scale. For example, if you only have 1 plate of microscopy images, the embeddings might underperform compared to a supervised bespoke model.
59
-
60
- ## How to Get Started with the Model
61
-
62
- You should be able to successfully run the below tests, which demonstrate how to use the model at inference time.
63
-
64
- ```python
65
- import pytest
66
- import torch
67
-
68
- from huggingface_mae import MAEModel
69
-
70
- huggingface_phenombeta_model_dir = "models/phenom_beta_huggingface"
71
- # huggingface_modelpath = "recursionpharma/test-pb-model"
72
-
73
-
74
- @pytest.fixture
75
- def huggingface_model():
76
- # Make sure you have the model/config downloaded from https://huggingface.co/recursionpharma/test-pb-model to this directory
77
- # huggingface-cli download recursionpharma/test-pb-model --local-dir=models/phenom_beta_huggingface
78
- huggingface_model = MAEModel.from_pretrained(huggingface_phenombeta_model_dir)
79
- huggingface_model.eval()
80
- return huggingface_model
81
-
82
-
83
- @pytest.mark.parametrize("C", [1, 4, 6, 11])
84
- @pytest.mark.parametrize("return_channelwise_embeddings", [True, False])
85
- def test_model_predict(huggingface_model, C, return_channelwise_embeddings):
86
- example_input_array = torch.randint(
87
- low=0,
88
- high=255,
89
- size=(2, C, 256, 256),
90
- dtype=torch.uint8,
91
- device=huggingface_model.device,
92
- )
93
- huggingface_model.return_channelwise_embeddings = return_channelwise_embeddings
94
- embeddings = huggingface_model.predict(example_input_array)
95
- expected_output_dim = 384 * C if return_channelwise_embeddings else 384
96
- assert embeddings.shape == (2, expected_output_dim)
97
  ```
98
-
99
-
100
- ## Training, evaluation and testing details
101
-
102
- See paper linked above for details on model training and evaluation. Primary hyperparameters are included in the repo linked above.
103
-
104
-
105
- ## Environmental Impact
106
-
107
- - **Hardware Type:** Nvidia H100 Hopper nodes
108
- - **Hours used:** 400
109
- - **Cloud Provider:** private cloud
110
- - **Carbon Emitted:** 138.24 kg co2 (roughly the equivalent of one car driving from Toronto to Montreal)
111
-
112
- **BibTeX:**
113
-
114
- ```TeX
115
- @inproceedings{kraus2024masked,
116
- title={Masked Autoencoders for Microscopy are Scalable Learners of Cellular Biology},
117
- author={Kraus, Oren and Kenyon-Dean, Kian and Saberian, Saber and Fallah, Maryam and McLean, Peter and Leung, Jess and Sharma, Vasudev and Khan, Ayla and Balakrishnan, Jia and Celik, Safiye and others},
118
- booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
119
- pages={11757--11768},
120
- year={2024}
121
- }
122
  ```
123
 
124
- ## Model Card Contact
 
125
 
126
- - Kian Kenyon-Dean: [email protected]
127
- - Oren Kraus: oren.kraus@recursion.com
128
- - Or, email: info@rxrx.ai
 
1
+ # Masked Autoencoders are Scalable Learners of Cellular Morphology
2
+ Official repo for Recursion's two recently accepted papers:
3
+ - Spotlight full-length paper at [CVPR 2024](https://cvpr.thecvf.com/Conferences/2024/AcceptedPapers) -- Masked Autoencoders for Microscopy are Scalable Learners of Cellular Biology
4
+ - Paper: https://arxiv.org/abs/2404.10242
5
+ - CVPR poster page with video: https://cvpr.thecvf.com/virtual/2024/poster/31565
6
+ - Spotlight workshop paper at [NeurIPS 2023 Generative AI & Biology workshop](https://openreview.net/group?id=NeurIPS.cc/2023/Workshop/GenBio)
7
+ - Paper: https://arxiv.org/abs/2309.16064
8
 
9
+ ![vit_diff_mask_ratios](https://github.com/recursionpharma/maes_microscopy/assets/109550980/c15f46b1-cdb9-41a7-a4af-bdc9684a971d)
10
 
 
 
11
 
12
+ ## Provided code
13
+ See the repo for the ingredients required to define our MAEs. Users seeking to re-implement training will need to stitch together the Encoder and Decoder modules according to their use case.
14
 
15
+ Furthermore, the baseline Vision Transformer backbone used in this work can be built with the following code snippet using timm:
 
16
  ```
17
+ import timm.models.vision_transformer as vit
18
+
19
+ def vit_base_patch16_256(**kwargs):
20
+ default_kwargs = dict(
21
+ img_size=256,
22
+ in_chans=6,
23
+ num_classes=0,
24
+ fc_norm=None,
25
+ class_token=True,
26
+ drop_path_rate=0.1,
27
+ init_values=0.0001,
28
+ block_fn=vit.ParallelScalingBlock,
29
+ qkv_bias=False,
30
+ qk_norm=True,
31
+ )
32
+ for k, v in kwargs.items():
33
+ default_kwargs[k] = v
34
+ return vit.vit_base_patch16_224(**default_kwargs)
 
 
 
 
 
 
35
  ```
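The CA-MAE-S/16 checkpoint described in MODELCARD.md uses the small variant, `vit_small_patch16_256` (see the import in `huggingface_mae.py`); a sketch of how an analogous small backbone could be built with timm is below (illustrative only; the repo's own `vit_small_patch16_256` hyperparameters may differ):

```python
import timm.models.vision_transformer as vit

def vit_small_patch16_256(**kwargs):
    # Illustrative only: same recipe as vit_base_patch16_256 above, swapping in ViT-S/16.
    default_kwargs = dict(
        img_size=256,
        in_chans=6,
        num_classes=0,
        fc_norm=None,
        class_token=True,
        drop_path_rate=0.1,
        init_values=0.0001,
        block_fn=vit.ParallelScalingBlock,
        qkv_bias=False,
        qk_norm=True,
    )
    default_kwargs.update(kwargs)
    return vit.vit_small_patch16_224(**default_kwargs)
```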
36
 
37
+ ## Provided models
38
+ A publicly available model for research can be found via Nvidia's BioNemo platform, which handles inference and auto-scaling: https://www.rxrx.ai/phenom
39
 
40
+ We have partnered with Nvidia to host a publicly available, smaller, and more flexible version of the MAE phenomics foundation model, called Phenom-Beta. Interested parties can access it directly through the Nvidia BioNemo API:
41
+ - https://blogs.nvidia.com/blog/drug-discovery-bionemo-generative-ai/
42
+ - https://www.youtube.com/watch?v=Gch6bX1toB0
config.yaml ADDED
@@ -0,0 +1,16 @@
1
+ # © Recursion Pharmaceuticals 2024
2
+ loss:
3
+ _target_: torch.nn.MSELoss # combine with fourier loss weighted at 0.01 mixing factor for best results
4
+ reduction: none
5
+ optimizer:
6
+ _target_: timm.optim.lion.Lion
7
+ _partial_: true
8
+ lr: &lr 1e-4 # 1e-4 for <= ViT-B, and 3e-5 for ViT-L
9
+ weight_decay: 0.05
10
+ betas: [0.9, 0.95]
11
+ lr_scheduler:
12
+ _target_: torch.optim.lr_scheduler.OneCycleLR
13
+ _partial_: true
14
+ max_lr: *lr # reuse the optimizer learning rate defined above
15
+ pct_start: 0.1
16
+ anneal_strategy: cos
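For context, the `_target_`/`_partial_` keys follow Hydra's instantiation convention; a minimal sketch of how a config like this is typically consumed (assuming `hydra-core` and `omegaconf` are installed; this wiring is not part of the files in this commit):

```python
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")
model = torch.nn.Linear(8, 8)  # stand-in module, just to illustrate the wiring

loss_fn = instantiate(cfg.loss)                              # torch.nn.MSELoss(reduction="none")
optimizer = instantiate(cfg.optimizer)(model.parameters())   # _partial_: true -> call with params
scheduler = instantiate(cfg.lr_scheduler)(optimizer, total_steps=1000)
```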
generate_reconstructions.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
huggingface_mae.py ADDED
@@ -0,0 +1,293 @@
1
+ from typing import Dict, Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from transformers import PretrainedConfig, PreTrainedModel
7
+
8
+ from loss import FourierLoss
9
+ from normalizer import Normalizer
10
+ from mae_modules import CAMAEDecoder, MAEDecoder, MAEEncoder
11
+ from mae_utils import flatten_images
12
+ from vit import (
13
+ generate_2d_sincos_pos_embeddings,
14
+ sincos_positional_encoding_vit,
15
+ vit_small_patch16_256,
16
+ )
17
+
18
+ TensorDict = Dict[str, torch.Tensor]
19
+
20
+
21
+ class MAEConfig(PretrainedConfig):
22
+ model_type = "MAE"
23
+
24
+ def __init__(
25
+ self,
26
+ mask_ratio=0.75,
27
+ encoder=None,
28
+ decoder=None,
29
+ loss=None,
30
+ optimizer=None,
31
+ input_norm=None,
32
+ fourier_loss=None,
33
+ fourier_loss_weight=0.0,
34
+ lr_scheduler=None,
35
+ use_MAE_weight_init=False,
36
+ crop_size=-1,
37
+ mask_fourier_loss=True,
38
+ return_channelwise_embeddings=False,
39
+ **kwargs,
40
+ ):
41
+ super().__init__(**kwargs)
42
+ self.mask_ratio = mask_ratio
43
+ self.encoder = encoder
44
+ self.decoder = decoder
45
+ self.loss = loss
46
+ self.optimizer = optimizer
47
+ self.input_norm = input_norm
48
+ self.fourier_loss = fourier_loss
49
+ self.fourier_loss_weight = fourier_loss_weight
50
+ self.lr_scheduler = lr_scheduler
51
+ self.use_MAE_weight_init = use_MAE_weight_init
52
+ self.crop_size = crop_size
53
+ self.mask_fourier_loss = mask_fourier_loss
54
+ self.return_channelwise_embeddings = return_channelwise_embeddings
55
+
56
+
57
+ class MAEModel(PreTrainedModel):
58
+ config_class = MAEConfig
59
+
60
+ # Loss metrics
61
+ TOTAL_LOSS = "loss"
62
+ RECON_LOSS = "reconstruction_loss"
63
+ FOURIER_LOSS = "fourier_loss"
64
+
65
+ def __init__(self, config: MAEConfig):
66
+ super().__init__(config)
67
+
68
+ self.mask_ratio = config.mask_ratio
69
+
70
+ # Could use Hydra to instantiate instead
71
+ self.encoder = MAEEncoder(
72
+ vit_backbone=sincos_positional_encoding_vit(
73
+ vit_backbone=vit_small_patch16_256(global_pool="avg")
74
+ ),
75
+ max_in_chans=11, # upper limit on number of input channels
76
+ channel_agnostic=True,
77
+ )
78
+ self.decoder = CAMAEDecoder(
79
+ depth=8,
80
+ embed_dim=512,
81
+ mlp_ratio=4,
82
+ norm_layer=nn.LayerNorm,
83
+ num_heads=16,
84
+ num_modalities=6,
85
+ qkv_bias=True,
86
+ tokens_per_modality=256,
87
+ )
88
+ self.input_norm = torch.nn.Sequential(
89
+ Normalizer(),
90
+ nn.InstanceNorm2d(None, affine=False, track_running_stats=False),
91
+ )
92
+
93
+ self.fourier_loss_weight = config.fourier_loss_weight
94
+ self.mask_fourier_loss = config.mask_fourier_loss
95
+ self.return_channelwise_embeddings = config.return_channelwise_embeddings
96
+ self.tokens_per_channel = 256  # hardcoded: patch 16 on a 256x256 crop gives 16 x 16 = 256 tokens per channel
97
+
98
+ # loss stuff
99
+ self.loss = torch.nn.MSELoss(reduction="none")
100
+
101
+ self.fourier_loss = FourierLoss(num_multimodal_modalities=6)
102
+ if self.fourier_loss_weight > 0 and self.fourier_loss is None:
103
+ raise ValueError(
104
+ "FourierLoss weight is activated but no fourier_loss was defined in constructor"
105
+ )
106
+ elif self.fourier_loss_weight >= 1:
107
+ raise ValueError(
108
+ "FourierLoss weight is too large to do mixing factor, weight should be < 1"
109
+ )
110
+
111
+ self.patch_size = int(self.encoder.vit_backbone.patch_embed.patch_size[0])
112
+
113
+ # projection layer between the encoder and decoder
114
+ self.encoder_decoder_proj = nn.Linear(
115
+ self.encoder.embed_dim, self.decoder.embed_dim, bias=True
116
+ )
117
+
118
+ self.decoder_pred = nn.Linear(
119
+ self.decoder.embed_dim,
120
+ self.patch_size**2
121
+ * (1 if self.encoder.channel_agnostic else self.in_chans),
122
+ bias=True,
123
+ ) # linear layer from decoder embedding to input dims
124
+
125
+ # overwrite decoder pos embeddings based on encoder params
126
+ self.decoder.pos_embeddings = generate_2d_sincos_pos_embeddings( # type: ignore[assignment]
127
+ self.decoder.embed_dim,
128
+ length=self.encoder.vit_backbone.patch_embed.grid_size[0],
129
+ use_class_token=self.encoder.vit_backbone.cls_token is not None,
130
+ num_modality=(
131
+ self.decoder.num_modalities if self.encoder.channel_agnostic else 1
132
+ ),
133
+ )
134
+
135
+ if config.use_MAE_weight_init:
136
+ w = self.encoder.vit_backbone.patch_embed.proj.weight.data
137
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
138
+
139
+ torch.nn.init.normal_(self.encoder.vit_backbone.cls_token, std=0.02)
140
+ torch.nn.init.normal_(self.decoder.mask_token, std=0.02)
141
+
142
+ self.apply(self._MAE_init_weights)
143
+
144
+ def setup(self, stage: str) -> None:
145
+ super().setup(stage)
146
+
147
+ def _MAE_init_weights(self, m):
148
+ if isinstance(m, nn.Linear):
149
+ torch.nn.init.xavier_uniform_(m.weight)
150
+ if isinstance(m, nn.Linear) and m.bias is not None:
151
+ nn.init.constant_(m.bias, 0)
152
+ elif isinstance(m, nn.LayerNorm):
153
+ nn.init.constant_(m.bias, 0)
154
+ nn.init.constant_(m.weight, 1.0)
155
+
156
+ @staticmethod
157
+ def decode_to_reconstruction(
158
+ encoder_latent: torch.Tensor,
159
+ ind_restore: torch.Tensor,
160
+ proj: torch.nn.Module,
161
+ decoder: MAEDecoder | CAMAEDecoder,
162
+ pred: torch.nn.Module,
163
+ ) -> torch.Tensor:
164
+ """Feed forward the encoder latent through the decoders necessary projections and transformations."""
165
+ decoder_latent_projection = proj(
166
+ encoder_latent
167
+ ) # projection from encoder.embed_dim to decoder.embed_dim
168
+ decoder_tokens = decoder.forward_masked(
169
+ decoder_latent_projection, ind_restore
170
+ ) # decoder.embed_dim output
171
+ predicted_reconstruction = pred(
172
+ decoder_tokens
173
+ ) # linear projection to input dim
174
+ return predicted_reconstruction[:, 1:, :] # drop class token
175
+
176
+ def forward(
177
+ self, imgs: torch.Tensor, constant_noise: Union[torch.Tensor, None] = None
178
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
179
+ imgs = self.input_norm(imgs)
180
+ latent, mask, ind_restore = self.encoder.forward_masked(
181
+ imgs, self.mask_ratio, constant_noise
182
+ ) # encoder blocks
183
+ reconstruction = self.decode_to_reconstruction(
184
+ latent,
185
+ ind_restore,
186
+ self.encoder_decoder_proj,
187
+ self.decoder,
188
+ self.decoder_pred,
189
+ )
190
+ return latent, reconstruction, mask
191
+
192
+ def compute_MAE_loss(
193
+ self,
194
+ reconstruction: torch.Tensor,
195
+ img: torch.Tensor,
196
+ mask: torch.Tensor,
197
+ ) -> Tuple[torch.Tensor, Dict[str, float]]:
198
+ """Computes final loss and returns specific values of component losses for metric reporting."""
199
+ loss_dict = {}
200
+ img = self.input_norm(img)
201
+ target_flattened = flatten_images(
202
+ img,
203
+ patch_size=self.patch_size,
204
+ channel_agnostic=self.encoder.channel_agnostic,
205
+ )
206
+
207
+ loss: torch.Tensor = self.loss(
208
+ reconstruction, target_flattened
209
+ ) # should be with MSE or MAE (L1) with reduction='none'
210
+ loss = loss.mean(
211
+ dim=-1
212
+ ) # average over embedding dim -> mean loss per patch (N,L)
213
+ loss = (loss * mask).sum() / mask.sum() # mean loss on masked patches only
214
+ loss_dict[self.RECON_LOSS] = loss.item()
215
+
216
+ # compute fourier loss
217
+ if self.fourier_loss_weight > 0:
218
+ floss: torch.Tensor = self.fourier_loss(reconstruction, target_flattened)
219
+ if not self.mask_fourier_loss:
220
+ floss = floss.mean()
221
+ else:
222
+ floss = floss.mean(dim=-1)
223
+ floss = (floss * mask).sum() / mask.sum()
224
+
225
+ loss_dict[self.FOURIER_LOSS] = floss.item()
226
+
227
+ # here we use a mixing factor to keep the loss magnitude appropriate with fourier
228
+ if self.fourier_loss_weight > 0:
229
+ loss = (1 - self.fourier_loss_weight) * loss + (
230
+ self.fourier_loss_weight * floss
231
+ )
232
+ return loss, loss_dict
233
+
234
+ def training_step(self, batch: TensorDict, batch_idx: int) -> TensorDict:
235
+ img = batch["pixels"]
236
+ latent, reconstruction, mask = self(img.clone())
237
+ full_loss, loss_dict = self.compute_MAE_loss(reconstruction, img.float(), mask)
238
+ return {
239
+ "loss": full_loss,
240
+ **loss_dict, # type: ignore[dict-item]
241
+ }
242
+
243
+ def validation_step(self, batch: TensorDict, batch_idx: int) -> TensorDict:
244
+ return self.training_step(batch, batch_idx)
245
+
246
+ def update_metrics(self, outputs: TensorDict, batch: TensorDict) -> None:
247
+ self.metrics["lr"].update(value=self.lr_scheduler.get_last_lr())
248
+ for key, value in outputs.items():
249
+ if key.endswith("loss"):
250
+ self.metrics[key].update(value)
251
+
252
+ def on_validation_batch_end( # type: ignore[override]
253
+ self,
254
+ outputs: TensorDict,
255
+ batch: TensorDict,
256
+ batch_idx: int,
257
+ dataloader_idx: int = 0,
258
+ ) -> None:
259
+ super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx)
260
+
261
+ def predict(self, imgs: torch.Tensor) -> torch.Tensor:
262
+ imgs = self.input_norm(imgs)
263
+ X = self.encoder.vit_backbone.forward_features(
264
+ imgs
265
+ ) # 3d tensor N x num_tokens x dim
266
+ if self.return_channelwise_embeddings:
267
+ N, _, d = X.shape
268
+ num_channels = imgs.shape[1]
269
+ X_reshaped = X[:, 1:, :].view(N, num_channels, self.tokens_per_channel, d)
270
+ pooled_segments = X_reshaped.mean(
271
+ dim=2
272
+ ) # Resulting shape: (N, num_channels, d)
273
+ latent = pooled_segments.view(N, num_channels * d).contiguous()
274
+ else:
275
+ latent = X[:, 1:, :].mean(dim=1) # 1 + 256 * C tokens
276
+ return latent
277
+
278
+ def save_pretrained(self, save_directory: str, **kwargs):
279
+ filename = kwargs.pop("filename", "model.safetensors")
280
+ modelpath = f"{save_directory}/{filename}"
281
+ self.config.save_pretrained(save_directory)
282
+ torch.save({"state_dict": self.state_dict()}, modelpath)
283
+
284
+ @classmethod
285
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
286
+ filename = kwargs.pop("filename", "model.safetensors")
287
+
288
+ modelpath = f"{pretrained_model_name_or_path}/{filename}"
289
+ config = MAEConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
290
+ state_dict = torch.load(modelpath, map_location="cpu")
291
+ model = cls(config, *model_args, **kwargs)
292
+ model.load_state_dict(state_dict["state_dict"])
293
+ return model
loss.py ADDED
@@ -0,0 +1,59 @@
1
+ # © Recursion Pharmaceuticals 2024
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ class FourierLoss(nn.Module):
7
+ def __init__(
8
+ self,
9
+ use_l1_loss: bool = True,
10
+ num_multimodal_modalities: int = 1, # set to 1 for vanilla MAE, 6 for channel-agnostic MAE
11
+ ) -> None:
12
+ """
13
+ Fourier transform loss is only sound when using L1 or L2 loss to compare the frequency domains
14
+ between the images / their radial histograms.
15
+
16
+ We will always set `reduction="none"` and enforce that the computation of any reductions from the
17
+ output of this loss be managed by the model under question.
18
+ """
19
+ super().__init__()
20
+ self.loss = (
21
+ nn.L1Loss(reduction="none") if use_l1_loss else nn.MSELoss(reduction="none")
22
+ )
23
+ self.num_modalities = num_multimodal_modalities
24
+
25
+ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
26
+ # input = reconstructed image, target = original image
27
+ # flattened images from MAE are (B, H*W, C), so, here we convert to B x C x H x W (note we assume H == W)
28
+ flattened_images = len(input.shape) == len(target.shape) == 3
29
+ if flattened_images:
30
+ B, H_W, C = input.shape
31
+ H_W = H_W // self.num_modalities
32
+ four_d_shape = (B, C * self.num_modalities, int(H_W**0.5), int(H_W**0.5))
33
+ input = input.view(*four_d_shape)
34
+ target = target.view(*four_d_shape)
35
+ else:
36
+ B, C, h, w = input.shape
37
+ H_W = h * w
38
+
39
+ if len(input.shape) != len(target.shape) != 4:
40
+ raise ValueError(
41
+ f"Invalid input shape: got {input.shape} and {target.shape}."
42
+ )
43
+
44
+ fft_reconstructed = torch.fft.fft2(input)
45
+ fft_original = torch.fft.fft2(target)
46
+
47
+ magnitude_reconstructed = torch.abs(fft_reconstructed)
48
+ magnitude_original = torch.abs(fft_original)
49
+
50
+ loss_tensor: torch.Tensor = self.loss(
51
+ magnitude_reconstructed, magnitude_original
52
+ )
53
+
54
+ if flattened_images:  # then the output loss should be reshaped back to the flattened token layout
57
+ loss_tensor = loss_tensor.reshape(B, H_W * self.num_modalities, C)
58
+
59
+ return loss_tensor
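For context, `MAEModel.compute_MAE_loss` in `huggingface_mae.py` mixes this loss into the pixel reconstruction loss at a small weight. A minimal standalone sketch on flattened channel-agnostic tokens (shapes match patch size 16 on 256x256 images with 6 channels; the random tensors are purely illustrative):

```python
import torch

from loss import FourierLoss

floss_fn = FourierLoss(use_l1_loss=True, num_multimodal_modalities=6)
recon = torch.randn(2, 6 * 16 * 16, 16 * 16)      # (B, tokens, patch_size**2), channel-agnostic layout
target = torch.randn(2, 6 * 16 * 16, 16 * 16)
per_patch = floss_fn(recon, target).mean(dim=-1)  # reduction="none", so the caller reduces
print(per_patch.shape)                            # torch.Size([2, 1536])
```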
mae_modules.py ADDED
@@ -0,0 +1,273 @@
1
+ # © Recursion Pharmaceuticals 2024
2
+ from functools import partial
3
+ from typing import Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from timm.models.helpers import checkpoint_seq
8
+ from timm.models.vision_transformer import Block, Mlp, VisionTransformer
9
+
10
+ from masking import transformer_random_masking
11
+ from vit import channel_agnostic_vit
12
+
13
+ # If interested in training new MAEs, combine an encoder and decoder into a new module, and you should
14
+ # leverage the flattening and unflattening utilities as needed from mae_utils.py.
15
+ # Be sure to use an encoder-decoder Linear projection layer to match encoder dims with decoder dimensions.
16
+ # As described in the paper, images are self-standardized at the start.
17
+
18
+
19
+ class SelfStandardize(nn.Module):
20
+ def __init__(self) -> None:
21
+ super().__init__()
22
+ self.self_standardize = nn.LazyInstanceNorm2d(
23
+ affine=False, track_running_stats=False
24
+ )
25
+
26
+ def forward(self, pixels: torch.Tensor) -> torch.Tensor:
27
+ x = pixels.float() / 255.0
28
+ return self.self_standardize(x)
29
+
30
+
31
+ class MAEEncoder(nn.Module):
32
+ def __init__(
33
+ self,
34
+ vit_backbone: VisionTransformer,
35
+ max_in_chans: int = 6,
36
+ channel_agnostic: bool = False,
37
+ ) -> None:
38
+ super().__init__()
39
+ if channel_agnostic:
40
+ self.vit_backbone = channel_agnostic_vit(
41
+ vit_backbone, max_in_chans=max_in_chans
42
+ )
43
+ else:
44
+ self.vit_backbone = vit_backbone
45
+ self.max_in_chans = max_in_chans
46
+ self.channel_agnostic = channel_agnostic
47
+
48
+ @property
49
+ def embed_dim(self) -> int:
50
+ return int(self.vit_backbone.embed_dim)
51
+
52
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
53
+ x = self.vit_backbone.forward_features(x)
54
+ x = self.vit_backbone.forward_head(x)
55
+ return x # type: ignore[no-any-return]
56
+
57
+ def forward_masked(
58
+ self,
59
+ x: torch.Tensor,
60
+ mask_ratio: float,
61
+ constant_noise: Union[torch.Tensor, None] = None,
62
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
63
+ x = self.vit_backbone.patch_embed(x)
64
+ x = self.vit_backbone._pos_embed(x) # adds class token
65
+ x_ = x[:, 1:, :] # no class token
66
+ x_, mask, ind_restore = transformer_random_masking(
67
+ x_, mask_ratio, constant_noise
68
+ )
69
+ x = torch.cat([x[:, :1, :], x_], dim=1) # add class token
70
+ x = self.vit_backbone.norm_pre(x)
71
+
72
+ if self.vit_backbone.grad_checkpointing and not torch.jit.is_scripting():
73
+ x = checkpoint_seq(self.vit_backbone.blocks, x)
74
+ else:
75
+ x = self.vit_backbone.blocks(x)
76
+ x = self.vit_backbone.norm(x)
77
+ return x, mask, ind_restore
78
+
79
+
80
+ class MAEDecoder(nn.Module):
81
+ def __init__(
82
+ self,
83
+ embed_dim: int = 512,
84
+ depth: int = 8,
85
+ num_heads: int = 16,
86
+ mlp_ratio: float = 4,
87
+ qkv_bias: bool = True,
88
+ norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), # type: ignore[assignment]
89
+ ) -> None:
90
+ super().__init__()
91
+ self.embed_dim = embed_dim
92
+ self.pos_embeddings = None # to be overwritten by MAE class
93
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
94
+ self.blocks = nn.Sequential(
95
+ *[
96
+ Block(
97
+ embed_dim,
98
+ num_heads,
99
+ mlp_ratio,
100
+ qkv_bias=qkv_bias,
101
+ norm_layer=norm_layer,
102
+ )
103
+ for i in range(depth)
104
+ ]
105
+ )
106
+ self.norm = norm_layer(embed_dim)
107
+
108
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
109
+ x = x + self.pos_embeddings
110
+ x = self.blocks(x)
111
+ x = self.norm(x)
112
+ return x # type: ignore[no-any-return]
113
+
114
+ def forward_masked(
115
+ self, x: torch.Tensor, ind_restore: torch.Tensor
116
+ ) -> torch.Tensor:
117
+ mask_tokens = self.mask_token.repeat(
118
+ x.shape[0], ind_restore.shape[1] + 1 - x.shape[1], 1
119
+ )
120
+ x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # remove class token
121
+ x_ = torch.gather(
122
+ x_, dim=1, index=ind_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
123
+ ) # unshuffle
124
+ x = torch.cat([x[:, :1, :], x_], dim=1) # add class token
125
+
126
+ x = x + self.pos_embeddings
127
+ x = self.blocks(x)
128
+ x = self.norm(x)
129
+ return x # type: ignore[no-any-return]
130
+
131
+
132
+ class CrossAttention(nn.Module):
133
+ def __init__(
134
+ self, embed_dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0
135
+ ):
136
+ super().__init__()
137
+ self.num_heads = num_heads
138
+ head_dim = embed_dim // num_heads
139
+ self.scale = head_dim**-0.5
140
+
141
+ self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
142
+ self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
143
+
144
+ self.attn_drop = nn.Dropout(attn_drop)
145
+ self.proj = nn.Linear(embed_dim, embed_dim)
146
+ self.proj_drop = nn.Dropout(proj_drop)
147
+
148
+ def forward(self, x, context):
149
+ B, N, C = x.shape
150
+ _, M, _ = context.shape
151
+
152
+ q = (
153
+ self.q(x)
154
+ .reshape(B, N, self.num_heads, C // self.num_heads)
155
+ .permute(0, 2, 1, 3)
156
+ )
157
+ kv = (
158
+ self.kv(context)
159
+ .reshape(B, M, 2, self.num_heads, C // self.num_heads)
160
+ .permute(2, 0, 3, 1, 4)
161
+ )
162
+ k, v = kv[0], kv[1]
163
+
164
+ attn = (q @ k.transpose(-2, -1)) * self.scale
165
+ attn = attn.softmax(dim=-1)
166
+ attn = self.attn_drop(attn)
167
+
168
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
169
+ x = self.proj(x)
170
+ x = self.proj_drop(x)
171
+ return x
172
+
173
+
174
+ class CAMAEDecoder(nn.Module):
175
+ def __init__(
176
+ self,
177
+ num_modalities: int = 6,
178
+ tokens_per_modality: int = 256,
179
+ embed_dim: int = 256,
180
+ depth: int = 2,
181
+ num_heads: int = 16,
182
+ mlp_ratio: float = 4,
183
+ qkv_bias: bool = True,
184
+ norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), # type: ignore[assignment]
185
+ ) -> None:
186
+ super().__init__()
187
+ self.num_modalities = num_modalities
188
+ self.tokens_per_modality = tokens_per_modality
189
+ self.embed_dim = embed_dim
190
+ self.pos_embeddings = None # to be overwritten by MAE class
191
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
192
+ self.placeholder = nn.Parameter(
193
+ torch.zeros(1, 1, embed_dim), requires_grad=False
194
+ )
195
+ self.modality_tokens = nn.ParameterList(
196
+ [
197
+ nn.Parameter(torch.zeros(1, 1, self.embed_dim))
198
+ for modality in range(self.num_modalities)
199
+ ]
200
+ )
201
+
202
+ self.cross_attention = CrossAttention(embed_dim=self.embed_dim)
203
+ self.mlp = Mlp(self.embed_dim, hidden_features=int(self.embed_dim * mlp_ratio))
204
+
205
+ self.decoders = nn.ModuleList(
206
+ [
207
+ nn.Sequential(
208
+ *[
209
+ Block(
210
+ embed_dim,
211
+ num_heads,
212
+ mlp_ratio,
213
+ qkv_bias=qkv_bias,
214
+ norm_layer=norm_layer,
215
+ )
216
+ for i in range(depth)
217
+ ]
218
+ )
219
+ for modality in range(self.num_modalities)
220
+ ]
221
+ )
222
+ # self.norm = norm_layer(embed_dim) # we decided to drop the last layer norm
223
+ self.context_norm = norm_layer(embed_dim)
224
+ self.query_norm = norm_layer(embed_dim)
225
+ self.out_norm = norm_layer(embed_dim)
226
+
227
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
228
+ x_m_s = []
229
+
230
+ modality_tokens_concat = torch.cat(
231
+ [
232
+ self.placeholder,
233
+ ] # placeholder for class token
234
+ + [
235
+ m_t.repeat(1, self.tokens_per_modality, 1)
236
+ for m_t in self.modality_tokens
237
+ ],
238
+ dim=1,
239
+ )
240
+
241
+ x = (
242
+ x + self.pos_embeddings + modality_tokens_concat
243
+ ) # add pos and tiled modality tokens
244
+ x_ = x[:, 1:, :] # no class token
245
+ for m, decoder in enumerate(
246
+ self.decoders
247
+ ): # iterate through modalities and decoders
248
+ x_m = x_[
249
+ :, m * self.tokens_per_modality : (m + 1) * self.tokens_per_modality, :
250
+ ]
251
+ x_m = self.cross_attention(self.query_norm(x_m), self.context_norm(x_))
252
+ x_m = x_m + self.mlp(self.out_norm(x_m))
253
+ x_m = decoder(x_m)
254
+ x_m_s.append(x_m)
255
+ x_m_s = torch.cat(x_m_s, dim=1) # concat all tokens
256
+ # x_m_s = self.norm(x_m_s) # we decided to drop the last layer norm
257
+ x_m_s = torch.cat([x[:, :1, :], x_m_s], dim=1) # add back class token
258
+
259
+ return x_m_s
260
+
261
+ def forward_masked(
262
+ self, x: torch.Tensor, ind_restore: torch.Tensor
263
+ ) -> torch.Tensor:
264
+ mask_tokens = self.mask_token.repeat(
265
+ x.shape[0], ind_restore.shape[1] + 1 - x.shape[1], 1
266
+ )
267
+ x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # remove class token
268
+ x_ = torch.gather(
269
+ x_, dim=1, index=ind_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
270
+ ) # unshuffle
271
+ x = torch.cat([x[:, :1, :], x_], dim=1) # add class token
272
+ x = self.forward(x)
273
+ return x
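A quick shape check for the CrossAttention module defined above (a minimal sketch; the batch, token counts, and embedding size are illustrative and assume the class is importable from this module):

import torch
# assumes CrossAttention as defined in the diff above
attn = CrossAttention(embed_dim=256, num_heads=8)
queries = torch.randn(2, 256, 256)   # (B, N, C): tokens for one modality
context = torch.randn(2, 1537, 256)  # (B, M, C): class token + 6 * 256 patch tokens
out = attn(queries, context)         # -> (2, 256, 256), same shape as the queries

This is the pattern CAMAEDecoder.forward relies on: each modality's token slice attends over the full (normed) token sequence before its per-modality decoder blocks.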
mae_utils.py ADDED
@@ -0,0 +1,70 @@
1
+ # © Recursion Pharmaceuticals 2024
2
+ import math
3
+
4
+ import torch
5
+
6
+
7
+ def flatten_images(
8
+ img: torch.Tensor, patch_size: int, channel_agnostic: bool = False
9
+ ) -> torch.Tensor:
10
+ """
11
+ Flattens 2D images into patch tokens, preserving the original pixel values
12
+
13
+ Parameters
14
+ ----------
15
+ img : input image tensor (N, C, H, W)
16
+
17
+ Returns
18
+ -------
19
+ flattened_img: flattened image tensor (N, L, patch_size**2 * C), or (N, C*L, patch_size**2) if channel_agnostic
20
+ """
21
+
22
+ if (img.shape[2] != img.shape[3]) or (img.shape[2] % patch_size != 0):
23
+ raise ValueError("image H must equal image W and be divisible by patch_size")
24
+ in_chans = img.shape[1]
25
+
26
+ h = w = int(img.shape[2] // patch_size)
27
+ x = img.reshape(shape=(img.shape[0], in_chans, h, patch_size, w, patch_size))
28
+
29
+ if channel_agnostic:
30
+ x = torch.permute(x, (0, 1, 2, 4, 3, 5)) # NCHPWQ -> NCHWPQ
31
+ x = x.reshape(shape=(img.shape[0], in_chans * h * w, int(patch_size**2)))
32
+ else:
33
+ x = torch.permute(x, (0, 2, 4, 3, 5, 1)) # NCHPWQ -> NHWPQC
34
+ x = x.reshape(shape=(img.shape[0], h * w, int(patch_size**2 * in_chans)))
35
+ return x
36
+
37
+
38
+ def unflatten_tokens(
39
+ tokens: torch.Tensor,
40
+ patch_size: int,
41
+ num_modalities: int = 1,
42
+ channel_agnostic: bool = False,
43
+ ) -> torch.Tensor:
44
+ """
45
+ Unflattens tokens (N, L, patch_size**2 * C) back into an image tensor (N, C, H, W), preserving pixel values
46
+
47
+ Parameters
48
+ ----------
49
+ tokens : input token tensor (N,L,patch_size**2 * C)
50
+
51
+ Returns
52
+ -------
53
+ img: image tensor (N,C,H,W)
54
+ """
55
+ if num_modalities > 1 and not channel_agnostic:
56
+ raise ValueError("Multiple modalities requires channel agnostic unflattening.")
57
+
58
+ h = w = int(math.sqrt(tokens.shape[1] // num_modalities))
59
+ if h * w != (tokens.shape[1] // num_modalities):
60
+ raise ValueError("sqrt of number of tokens not integer")
61
+
62
+ if channel_agnostic:
63
+ x = tokens.reshape(shape=(tokens.shape[0], -1, h, w, patch_size, patch_size))
64
+ x = torch.permute(x, (0, 1, 2, 4, 3, 5)) # NCHWPQ -> NCHPWQ
65
+ else:
66
+ x = tokens.reshape(shape=(tokens.shape[0], h, w, patch_size, patch_size, -1))
67
+ x = torch.permute(x, (0, 5, 1, 3, 2, 4)) # NHWPQC -> NCHPWQ
68
+ img = x.reshape(shape=(x.shape[0], -1, h * patch_size, h * patch_size))
69
+
70
+ return img
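A round-trip sketch for the helpers above, following the channel-agnostic path (shapes are illustrative; the two functions are exact inverses of each other):

import torch
from mae_utils import flatten_images, unflatten_tokens

img = torch.randn(2, 6, 256, 256)                                   # (N, C, H, W)
tokens = flatten_images(img, patch_size=16, channel_agnostic=True)  # (2, 6*16*16, 16*16)
recon = unflatten_tokens(tokens, patch_size=16, num_modalities=6, channel_agnostic=True)
assert torch.equal(recon, img)  # flatten/unflatten only reshape and permute, no arithmetic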
masking.py ADDED
@@ -0,0 +1,51 @@
1
+ # © Recursion Pharmaceuticals 2024
2
+ from typing import Tuple, Union
3
+
4
+ import torch
5
+
6
+
7
+ def transformer_random_masking(
8
+ x: torch.Tensor, mask_ratio: float, constant_noise: Union[torch.Tensor, None] = None
9
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
10
+ """
11
+ Randomly mask patches per sample
12
+
13
+ Parameters
14
+ ----------
15
+ x : token tensor (N, L, D)
16
+ mask_ratio: float - ratio of image to mask
17
+ constant_noise: optional tensor of shape (N, L); if provided, it is used in place of random noise to produce reproducible masks
18
+
19
+ Returns
20
+ -------
21
+ x_masked : sub-sampled version of x, shape (N, int(L * (1 - mask_ratio)), D)
23
+ mask : binary mask indicating masked tokens (1 where masked), shape (N, L)
24
+ ind_restore : indices that restore the original token order, needed by the decoder
24
+ """
25
+
26
+ N, L, D = x.shape # batch, length, dim
27
+ len_keep = int(L * (1 - mask_ratio))
28
+
29
+ # use random noise to generate batch based random masks
30
+ if constant_noise is not None:
31
+ noise = constant_noise
32
+ else:
33
+ noise = torch.rand(N, L, device=x.device)
34
+
35
+ shuffled_tokens = torch.argsort(noise, dim=1) # shuffled index
36
+ ind_restore = torch.argsort(shuffled_tokens, dim=1) # unshuffled index
37
+
38
+ # get masked input
39
+ tokens_to_keep = shuffled_tokens[:, :len_keep] # keep the first len_keep indices
40
+ x_masked = torch.gather(
41
+ x, dim=1, index=tokens_to_keep.unsqueeze(-1).repeat(1, 1, D)
42
+ )
43
+
44
+ # get binary mask used for loss masking: 0 is keep, 1 is remove
45
+ mask = torch.ones([N, L], device=x.device)
46
+ mask[:, :len_keep] = 0
47
+ mask = torch.gather(
48
+ mask, dim=1, index=ind_restore
49
+ ) # unshuffle to get the binary mask
50
+
51
+ return x_masked, mask, ind_restore
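A small usage sketch for the masking helper above (tensor sizes are illustrative):

import torch
from masking import transformer_random_masking

x = torch.randn(2, 1536, 768)  # (N, L, D) patch tokens, no class token
x_masked, mask, ind_restore = transformer_random_masking(x, mask_ratio=0.75)
# x_masked:    (2, 384, 768)  keeps the 25% of tokens that were not masked
# mask:        (2, 1536)      binary, 1 where a token was masked (used for the loss)
# ind_restore: (2, 1536)      indices the decoder uses to unshuffle its mask tokens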
normalizer.py ADDED
@@ -0,0 +1,7 @@
1
+ import torch
2
+
3
+
4
+ class Normalizer(torch.nn.Module):
5
+ def forward(self, pixels: torch.Tensor) -> torch.Tensor:
6
+ pixels = pixels.float()
7
+ return pixels / 255.0
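Usage is a one-liner; the module simply casts to float and rescales uint8 pixels into [0, 1]:

import torch
from normalizer import Normalizer

pixels = torch.randint(0, 256, (2, 6, 256, 256), dtype=torch.uint8)
scaled = Normalizer()(pixels)  # float32 values in [0.0, 1.0]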
pyproject.toml ADDED
@@ -0,0 +1,34 @@
1
+ [build-system]
2
+ requires = ["setuptools >= 61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "maes_microscopy_project"
7
+ version = "0.1.0"
8
+ authors = [
9
+ {name = "kian-kd", email = "[email protected]"},
10
+ {name = "Laksh47", email = "[email protected]"},
11
+ ]
12
+ requires-python = ">=3.10.4"
13
+
14
+ dependencies = [
15
+ "huggingface-hub",
16
+ "timm",
17
+ "torch>=2.3",
18
+ "torchmetrics",
19
+ "torchvision",
20
+ "tqdm",
21
+ "transformers",
22
+ "xformers",
23
+ "zarr",
24
+ "pytorch-lightning>=2.1",
25
+ "matplotlib",
26
+ "scikit-image",
27
+ "ipykernel",
28
+ "isort",
29
+ "ruff",
30
+ "pytest",
31
+ ]
32
+
33
+ [tool.setuptools]
34
+ py-modules = []
sample/AA41_s1_1.jp2 ADDED
sample/AA41_s1_2.jp2 ADDED
sample/AA41_s1_3.jp2 ADDED
sample/AA41_s1_4.jp2 ADDED
sample/AA41_s1_5.jp2 ADDED
sample/AA41_s1_6.jp2 ADDED
test_huggingface_mae.py ADDED
@@ -0,0 +1,32 @@
1
+ import pytest
2
+ import torch
3
+
4
+ from huggingface_mae import MAEModel
5
+
6
+ huggingface_phenombeta_model_dir = "."
7
+ # huggingface_modelpath = "recursionpharma/test-pb-model"
8
+
9
+
10
+ @pytest.fixture
11
+ def huggingface_model():
12
+ # Make sure you have the model/config downloaded from https://huggingface.co/recursionpharma/test-pb-model to this directory
13
+ # huggingface-cli download recursionpharma/test-pb-model --local-dir=.
14
+ huggingface_model = MAEModel.from_pretrained(huggingface_phenombeta_model_dir)
15
+ huggingface_model.eval()
16
+ return huggingface_model
17
+
18
+
19
+ @pytest.mark.parametrize("C", [1, 4, 6, 11])
20
+ @pytest.mark.parametrize("return_channelwise_embeddings", [True, False])
21
+ def test_model_predict(huggingface_model, C, return_channelwise_embeddings):
22
+ example_input_array = torch.randint(
23
+ low=0,
24
+ high=255,
25
+ size=(2, C, 256, 256),
26
+ dtype=torch.uint8,
27
+ device=huggingface_model.device,
28
+ )
29
+ huggingface_model.return_channelwise_embeddings = return_channelwise_embeddings
30
+ embeddings = huggingface_model.predict(example_input_array)
31
+ expected_output_dim = 384 * C if return_channelwise_embeddings else 384
32
+ assert embeddings.shape == (2, expected_output_dim)
vit.py ADDED
@@ -0,0 +1,309 @@
1
+ # © Recursion Pharmaceuticals 2024
2
+ import timm.models.vision_transformer as vit
3
+ import torch
4
+
5
+
6
+ def generate_2d_sincos_pos_embeddings(
7
+ embedding_dim: int,
8
+ length: int,
9
+ scale: float = 10000.0,
10
+ use_class_token: bool = True,
11
+ num_modality: int = 1,
12
+ ) -> torch.nn.Parameter:
13
+ """
14
+ Generate 2-dimensional sin/cos positional embeddings
15
+
16
+ Parameters
17
+ ----------
18
+ embedding_dim : int
19
+ embedding dimension used in vit
20
+ length : int
21
+ number of tokens along height or width of image after patching (assuming square)
22
+ scale : float
23
+ scale for sin/cos functions
24
+ use_class_token : bool
25
+ True - prepend a zero vector for the class token, False - no vector added
26
+ num_modality: number of channel modalities (default 1); the sin/cos encoding is
27
+ tiled once per modality, so every modality shares the same spatial embedding.
28
+
29
+ Returns
30
+ -------
31
+ positional_encoding : torch.nn.Parameter (requires_grad=False)
32
+ positional encoding to add to vit patch encodings
33
+ [num_modality*length*length, embedding_dim] or [1+num_modality*length*length, embedding_dim]
34
+ (w/ or w/o cls_token)
35
+ """
36
+
37
+ linear_positions = torch.arange(length, dtype=torch.float32)
38
+ height_mesh, width_mesh = torch.meshgrid(
39
+ linear_positions, linear_positions, indexing="ij"
40
+ )
41
+ positional_dim = embedding_dim // 4 # accommodate sin and cos embeddings for both height and width
42
+ positional_weights = (
43
+ torch.arange(positional_dim, dtype=torch.float32) / positional_dim
44
+ )
45
+ positional_weights = 1.0 / (scale**positional_weights)
46
+
47
+ height_weights = torch.outer(height_mesh.flatten(), positional_weights)
48
+ width_weights = torch.outer(width_mesh.flatten(), positional_weights)
49
+
50
+ positional_encoding = torch.cat(
51
+ [
52
+ torch.sin(height_weights),
53
+ torch.cos(height_weights),
54
+ torch.sin(width_weights),
55
+ torch.cos(width_weights),
56
+ ],
57
+ dim=1,
58
+ )[None, :, :]
59
+
60
+ # repeat positional encoding for multiple channel modalities
61
+ positional_encoding = positional_encoding.repeat(1, num_modality, 1)
62
+
63
+ if use_class_token:
64
+ class_token = torch.zeros([1, 1, embedding_dim], dtype=torch.float32)
65
+ positional_encoding = torch.cat([class_token, positional_encoding], dim=1)
66
+
67
+ positional_encoding = torch.nn.Parameter(positional_encoding, requires_grad=False)
68
+
69
+ return positional_encoding
70
+
71
+
72
+ class ChannelAgnosticPatchEmbed(vit.PatchEmbed): # type: ignore[misc]
73
+ def __init__(
74
+ self,
75
+ img_size: int,
76
+ patch_size: int,
77
+ embed_dim: int,
78
+ bias: bool = True,
79
+ ) -> None:
80
+ super().__init__(
81
+ img_size=img_size,
82
+ patch_size=patch_size,
83
+ in_chans=1, # in_chans is used by self.proj, which we override anyway
84
+ embed_dim=embed_dim,
85
+ norm_layer=None,
86
+ flatten=False,
87
+ bias=bias,
88
+ )
89
+ # channel-agnostic MAE has a single projection for all chans
90
+ self.proj = torch.nn.Conv2d(
91
+ 1, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
92
+ )
93
+
94
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
95
+ in_chans = x.shape[1]
96
+ x = torch.stack(
97
+ [self.proj(x[:, i : i + 1]) for i in range(in_chans)], dim=2
98
+ ) # single shared projection applied to each channel
99
+ x = x.flatten(2).transpose(1, 2) # BCMHW -> BNC
100
+ return x
101
+
102
+
103
+ class ChannelAgnosticViT(vit.VisionTransformer): # type: ignore[misc]
104
+ def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
105
+ # rewrite https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L586
106
+ to_cat = []
107
+ if self.cls_token is not None:
108
+ to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
109
+
110
+ # TODO: upgrade timm to get access to register tokens
111
+ # if self.vit_backbone.reg_token is not None:
112
+ # to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
113
+
114
+ # MAIN DIFFERENCE from timm - we DYNAMICALLY ADD POS EMBEDDINGS based on the shape of the inputs
115
+ # this supports having CA-MAEs actually be channel-agnostic at inference time
116
+ if self.no_embed_class:
117
+ x = x + self.pos_embed[:, : x.shape[1]]
118
+ if to_cat:
119
+ x = torch.cat(to_cat + [x], dim=1)
120
+ else:
121
+ if to_cat:
122
+ x = torch.cat(to_cat + [x], dim=1)
123
+ x = x + self.pos_embed[:, : x.shape[1]]
124
+ return self.pos_drop(x) # type: ignore[no-any-return]
125
+
126
+
127
+ def channel_agnostic_vit(
128
+ vit_backbone: vit.VisionTransformer, max_in_chans: int
129
+ ) -> vit.VisionTransformer:
130
+ # replace patch embedding with channel-agnostic version
131
+ vit_backbone.patch_embed = ChannelAgnosticPatchEmbed(
132
+ img_size=vit_backbone.patch_embed.img_size[0],
133
+ patch_size=vit_backbone.patch_embed.patch_size[0],
134
+ embed_dim=vit_backbone.embed_dim,
135
+ )
136
+
137
+ # replace positional embedding with channel-agnostic version
138
+ vit_backbone.pos_embed = generate_2d_sincos_pos_embeddings(
139
+ embedding_dim=vit_backbone.embed_dim,
140
+ length=vit_backbone.patch_embed.grid_size[0],
141
+ use_class_token=vit_backbone.cls_token is not None,
142
+ num_modality=max_in_chans,
143
+ )
144
+
145
+ # change the class to be ChannelAgnostic so that it actually uses the new _pos_embed
146
+ vit_backbone.__class__ = ChannelAgnosticViT
147
+ return vit_backbone
148
+
149
+
150
+ def sincos_positional_encoding_vit(
151
+ vit_backbone: vit.VisionTransformer, scale: float = 10000.0
152
+ ) -> vit.VisionTransformer:
153
+ """Attaches no-grad sin-cos positional embeddings to a pre-constructed ViT backbone model.
154
+
155
+ Parameters
156
+ ----------
157
+ vit_backbone : timm.models.vision_transformer.VisionTransformer
158
+ the constructed vision transformer from timm
159
+ scale : float (default 10000.0)
160
+ hyperparameter for sincos positional embeddings, recommend keeping at 10,000
161
+
162
+ Returns
163
+ -------
164
+ timm.models.vision_transformer.VisionTransformer
165
+ the same ViT but with fixed no-grad positional encodings to add to vit patch encodings
166
+ """
167
+ # length: number of tokens along height or width of image after patching (assuming square)
168
+ length = (
169
+ vit_backbone.patch_embed.img_size[0] // vit_backbone.patch_embed.patch_size[0]
170
+ )
171
+ pos_embeddings = generate_2d_sincos_pos_embeddings(
172
+ vit_backbone.embed_dim,
173
+ length=length,
174
+ scale=scale,
175
+ use_class_token=vit_backbone.cls_token is not None,
176
+ )
177
+ # note, if the model had weight_init == 'skip', this might get overwritten
178
+ vit_backbone.pos_embed = pos_embeddings
179
+ return vit_backbone
180
+
181
+
182
+ def vit_small_patch16_256(**kwargs):
183
+ default_kwargs = dict(
184
+ img_size=256,
185
+ in_chans=6,
186
+ num_classes=0,
187
+ fc_norm=None,
188
+ class_token=True,
189
+ drop_path_rate=0.1,
190
+ init_values=0.0001,
191
+ block_fn=vit.ParallelScalingBlock,
192
+ qkv_bias=False,
193
+ qk_norm=True,
194
+ )
195
+ for k, v in kwargs.items():
196
+ default_kwargs[k] = v
197
+ return vit.vit_small_patch16_224(**default_kwargs)
198
+
199
+
200
+ def vit_small_patch32_512(**kwargs):
201
+ default_kwargs = dict(
202
+ img_size=512,
203
+ in_chans=6,
204
+ num_classes=0,
205
+ fc_norm=None,
206
+ class_token=True,
207
+ drop_path_rate=0.1,
208
+ init_values=0.0001,
209
+ block_fn=vit.ParallelScalingBlock,
210
+ qkv_bias=False,
211
+ qk_norm=True,
212
+ )
213
+ for k, v in kwargs.items():
214
+ default_kwargs[k] = v
215
+ return vit.vit_small_patch32_384(**default_kwargs)
216
+
217
+
218
+ def vit_base_patch8_256(**kwargs):
219
+ default_kwargs = dict(
220
+ img_size=256,
221
+ in_chans=6,
222
+ num_classes=0,
223
+ fc_norm=None,
224
+ class_token=True,
225
+ drop_path_rate=0.1,
226
+ init_values=0.0001,
227
+ block_fn=vit.ParallelScalingBlock,
228
+ qkv_bias=False,
229
+ qk_norm=True,
230
+ )
231
+ for k, v in kwargs.items():
232
+ default_kwargs[k] = v
233
+ return vit.vit_base_patch8_224(**default_kwargs)
234
+
235
+
236
+ def vit_base_patch16_256(**kwargs):
237
+ default_kwargs = dict(
238
+ img_size=256,
239
+ in_chans=6,
240
+ num_classes=0,
241
+ fc_norm=None,
242
+ class_token=True,
243
+ drop_path_rate=0.1,
244
+ init_values=0.0001,
245
+ block_fn=vit.ParallelScalingBlock,
246
+ qkv_bias=False,
247
+ qk_norm=True,
248
+ )
249
+ for k, v in kwargs.items():
250
+ default_kwargs[k] = v
251
+ return vit.vit_base_patch16_224(**default_kwargs)
252
+
253
+
254
+ def vit_base_patch32_512(**kwargs):
255
+ default_kwargs = dict(
256
+ img_size=512,
257
+ in_chans=6,
258
+ num_classes=0,
259
+ fc_norm=None,
260
+ class_token=True,
261
+ drop_path_rate=0.1,
262
+ init_values=0.0001,
263
+ block_fn=vit.ParallelScalingBlock,
264
+ qkv_bias=False,
265
+ qk_norm=True,
266
+ )
267
+ for k, v in kwargs.items():
268
+ default_kwargs[k] = v
269
+ return vit.vit_base_patch32_384(**default_kwargs)
270
+
271
+
272
+ def vit_large_patch8_256(**kwargs):
273
+ default_kwargs = dict(
274
+ img_size=256,
275
+ in_chans=6,
276
+ num_classes=0,
277
+ fc_norm=None,
278
+ class_token=True,
279
+ patch_size=8,
280
+ embed_dim=1024,
281
+ depth=24,
282
+ num_heads=16,
283
+ drop_path_rate=0.3,
284
+ init_values=0.0001,
285
+ block_fn=vit.ParallelScalingBlock,
286
+ qkv_bias=False,
287
+ qk_norm=True,
288
+ )
289
+ for k, v in kwargs.items():
290
+ default_kwargs[k] = v
291
+ return vit.VisionTransformer(**default_kwargs)
292
+
293
+
294
+ def vit_large_patch16_256(**kwargs):
295
+ default_kwargs = dict(
296
+ img_size=256,
297
+ in_chans=6,
298
+ num_classes=0,
299
+ fc_norm=None,
300
+ class_token=True,
301
+ drop_path_rate=0.3,
302
+ init_values=0.0001,
303
+ block_fn=vit.ParallelScalingBlock,
304
+ qkv_bias=False,
305
+ qk_norm=True,
306
+ )
307
+ for k, v in kwargs.items():
308
+ default_kwargs[k] = v
309
+ return vit.vit_large_patch16_384(**default_kwargs)
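One plausible way to combine the constructors above, sketched here for orientation (the wrapper order and the forward_features call are illustrative, not prescribed by this diff):

import torch
from vit import (
    channel_agnostic_vit,
    sincos_positional_encoding_vit,
    vit_small_patch16_256,
)

# fixed (no-grad) sin/cos positional embeddings on a 6-channel ViT-S/16 backbone
backbone = sincos_positional_encoding_vit(vit_small_patch16_256())

# channel-agnostic variant: one shared patch projection, pos-embed tiled per channel
ca_backbone = channel_agnostic_vit(vit_small_patch16_256(), max_in_chans=11)
tokens = ca_backbone.forward_features(torch.randn(1, 3, 256, 256))  # works for any channel count up to 11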
vit_encoder.py ADDED
@@ -0,0 +1,61 @@
1
+ # © Recursion Pharmaceuticals 2024
2
+ from typing import Dict
3
+
4
+ import timm.models.vision_transformer as vit
5
+ import torch
6
+
7
+
8
+ def build_imagenet_baselines() -> Dict[str, torch.jit.ScriptModule]:
9
+ """This returns the prepped imagenet encoders from timm, not bad for microscopy data."""
10
+ vit_backbones = [
11
+ _make_vit(vit.vit_small_patch16_384),
12
+ _make_vit(vit.vit_base_patch16_384),
13
+ _make_vit(vit.vit_base_patch8_224),
14
+ _make_vit(vit.vit_large_patch16_384),
15
+ ]
16
+ model_names = [
17
+ "vit_small_patch16_384",
18
+ "vit_base_patch16_384",
19
+ "vit_base_patch8_224",
20
+ "vit_large_patch16_384",
21
+ ]
22
+ imagenet_encoders = list(map(_make_torchscripted_encoder, vit_backbones))
23
+ return {name: model for name, model in zip(model_names, imagenet_encoders)}
24
+
25
+
26
+ def _make_torchscripted_encoder(vit_backbone) -> torch.jit.ScriptModule:
27
+ dummy_input = torch.testing.make_tensor(
28
+ (2, 6, 256, 256),
29
+ low=0,
30
+ high=255,
31
+ dtype=torch.uint8,
32
+ device=torch.device("cpu"),
33
+ )
34
+ encoder = torch.nn.Sequential(
35
+ Normalizer(),
36
+ torch.nn.LazyInstanceNorm2d(
37
+ affine=False, track_running_stats=False
38
+ ), # this module performs self-standardization, very important
39
+ vit_backbone,
40
+ ).to(device="cpu")
41
+ _ = encoder(dummy_input) # get those lazy modules built
42
+ return torch.jit.freeze(torch.jit.script(encoder.eval()))
43
+
44
+
45
+ def _make_vit(constructor):
46
+ return constructor(
47
+ pretrained=True, # download imagenet weights
48
+ img_size=256, # 256x256 crops
49
+ in_chans=6, # we expect 6-channel microscopy images
50
+ num_classes=0,
51
+ fc_norm=None,
52
+ class_token=True,
53
+ global_pool="avg", # minimal perf diff btwn "cls" and "avg"
54
+ )
55
+
56
+
57
+ class Normalizer(torch.nn.Module):
58
+ def forward(self, pixels: torch.Tensor) -> torch.Tensor:
59
+ pixels = pixels.float()
60
+ pixels /= 255.0
61
+ return pixels
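A usage sketch for the baselines above (a minimal example; building the dictionary downloads ImageNet weights via timm, and the key used here matches model_names in build_imagenet_baselines):

import torch
from vit_encoder import build_imagenet_baselines

encoders = build_imagenet_baselines()  # {name: TorchScript encoder}
imgs = torch.randint(0, 255, (2, 6, 256, 256), dtype=torch.uint8)
with torch.no_grad():
    emb = encoders["vit_small_patch16_384"](imgs)  # (2, 384) average-pooled embeddings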