DDSC: Masked Discrete Diffusion for Gene Expression

Members

  • Justin Jung (CZII)

Code:

The code is public at https://github.com/jsjung00/diffusePerturb. A more detailed write-up is also linked there.

Overview and motivation

Gene expression count vectors are difficult for machine learning models to handle. Their distribution is sharply peaked at small values but also long-tailed. Moreover, count distributions can be cell dependent and highly variable from sample to sample. This makes it challenging to train a model to predict count vectors directly. Suppose the model predicts raw counts, and in particular predicts the mean (e.g. under an L2 loss). Without a hand-crafted loss function, the loss is then dominated by the large counts, and the model is not robust to errors on small counts.
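The L2 argument can be made concrete with a toy calculation. The numbers below are made up for illustration and are not Tahoe data. Three lowly expressed genes are each off by 2 counts, which is a large relative error. One highly expressed gene is off by only 5%. The squared error is still almost entirely due to the large count.

```python
def squared_error_share(true_counts, predicted_counts):
    """Return each gene's share of the total squared (L2) error."""
    errors = [(t - p) ** 2 for t, p in zip(true_counts, predicted_counts)]
    total = sum(errors)
    return [e / total for e in errors]


# Hypothetical cell: three lowly expressed genes and one highly expressed gene.
true_counts = [0, 1, 2, 400]
# Low-count genes are off by 2 (e.g. predicting 2 when the truth is 0);
# the high-count gene is off by 20, i.e. only 5% in relative terms.
predicted_counts = [2, 3, 4, 420]

shares = squared_error_share(true_counts, predicted_counts)
print(round(shares[-1], 3))  # 0.971: the single large count drives ~97% of the loss
```

A model minimizing this loss gains far more by fixing the one large count than by getting the small counts right.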

For these reasons, it is natural to work with gene expression count ranks instead. Ranks also have a natural categorical interpretation. We define up to n ranks, or buckets, and each one can be treated as a discrete token.
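One way counts could be turned into categorical rank tokens is sketched below. This is an assumed scheme, not necessarily the one used in DDSC, and the function name `counts_to_rank_buckets` is hypothetical. Zero counts map to token 0. Nonzero counts are ranked within the cell, with ties sharing a rank, and split into n - 1 quantile buckets.

```python
def counts_to_rank_buckets(counts, n_buckets):
    """Map one cell's raw counts to rank tokens in {0, ..., n_buckets - 1}.

    Hypothetical scheme: zeros -> token 0; nonzero counts are ranked within the
    cell (tied counts share their lowest rank) and spread over tokens
    1..n_buckets-1, with higher tokens meaning higher relative expression.
    """
    nonzero_sorted = sorted(c for c in counts if c > 0)
    m = len(nonzero_sorted)
    first_rank = {}
    for rank, c in enumerate(nonzero_sorted):
        first_rank.setdefault(c, rank)  # keep the first (lowest) rank for ties
    # rank is in [0, m - 1], so rank * (n_buckets - 1) // m is in [0, n_buckets - 2]
    return [0 if c == 0 else 1 + first_rank[c] * (n_buckets - 1) // m for c in counts]


tokens = counts_to_rank_buckets([0, 3, 1, 50, 1, 0, 8, 200], n_buckets=5)
print(tokens)  # [0, 2, 1, 3, 1, 0, 3, 4]
```

Because the tokens depend only on the within-cell ordering, scaling every count by the same factor leaves them unchanged. This gives some robustness to cell-dependent sequencing depth.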

With this motivation, we frame the problem through a sequence-to-sequence lens, and more specifically as a generative modeling problem. Given a sequence of count-rank tokens in which some tokens are missing ([MASK]), we aim to predict, or inpaint, the masked count ranks.

This framing fits naturally with a masked discrete diffusion language model, which we train on the Tahoe 100M dataset.
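A minimal sketch of a masked discrete diffusion training step is shown below. It assumes the common linear-schedule formulation used by masked diffusion language models such as MDLM, and the DDSC repository may differ. At time t, each token is independently replaced by [MASK] with probability t. The loss is the cross-entropy on the masked positions, weighted by 1/t. `MASK`, `predict_probs`, and the uniform stand-in model are hypothetical, and the real denoiser would also be conditioned on gene IDs.

```python
import math
import random

MASK = -1  # hypothetical id for the [MASK] token


def forward_mask(tokens, t, rng):
    """Forward (noising) process: each token becomes MASK independently with probability t."""
    return [MASK if rng.random() < t else tok for tok in tokens]


def masked_diffusion_loss(clean, noisy, t, predict_probs):
    """Cross-entropy on the masked positions only, weighted by 1/t.

    predict_probs(noisy) stands in for the denoiser (hypothetical signature): it
    returns, for every position, a list of probabilities over the rank buckets.
    Unmasked positions contribute nothing because they are already known.
    """
    probs = predict_probs(noisy)
    nll = sum(-math.log(probs[i][clean[i]])
              for i in range(len(clean)) if noisy[i] == MASK)
    return nll / t


def uniform_probs(noisy, n_buckets=5):
    """Trivial stand-in model that assigns equal probability to every bucket."""
    return [[1.0 / n_buckets] * n_buckets for _ in noisy]


rng = random.Random(0)
clean = [0, 2, 1, 3, 1, 0, 3, 4]  # rank tokens for one cell
t = rng.uniform(1e-3, 1.0)        # diffusion time, sampled per training example
noisy = forward_mask(clean, t, rng)
loss = masked_diffusion_loss(clean, noisy, t, uniform_probs)
```

In practice this loss would be averaged over a batch of cells, each with its own sampled t.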

Methods

  • We train a conditional masked discrete diffusion language model on the Tahoe100M gene expression data.
  • We have trained the model for 70k batches, or approximately 4.5 cell samples.
  • We condition on the gene IDs, and we support conditional masked inpainting generation at inference time.
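Conditional inpainting at inference time can be sketched with a standard masked-diffusion sampler. This is an assumption about how such a model is typically sampled, not the DDSC code. Observed rank tokens stay fixed. At each step from time t to s, every still-masked token is revealed with probability (t - s) / t, drawn from the model's prediction. `inpaint`, `toy_denoiser`, and the `denoiser(gene_ids, tokens)` signature are hypothetical.

```python
import random

MASK = -1  # hypothetical id for the [MASK] token


def inpaint(gene_ids, tokens, denoiser, n_steps, rng):
    """Fill MASK positions in `tokens` while keeping observed tokens fixed.

    denoiser(gene_ids, tokens) is a hypothetical stand-in for the trained model,
    conditioned on gene IDs; it returns one probability list per position.
    """
    x = list(tokens)
    for step in range(n_steps, 0, -1):
        t, s = step / n_steps, (step - 1) / n_steps
        probs = denoiser(gene_ids, x)
        for i, tok in enumerate(x):
            # Going from time t to s, a masked token is revealed with prob (t - s) / t;
            # at the last step (s = 0) this is 1, so every mask is filled.
            if tok == MASK and rng.random() < (t - s) / t:
                x[i] = rng.choices(range(len(probs[i])), weights=probs[i])[0]
    return x


def toy_denoiser(gene_ids, x, n_buckets=5):
    """Hypothetical model: puts 80% of the mass on bucket (gene_id % n_buckets)."""
    out = []
    for g in gene_ids:
        p = [0.05] * n_buckets
        p[g % n_buckets] = 1 - 0.05 * (n_buckets - 1)
        out.append(p)
    return out


gene_ids = [101, 7, 42, 13]
observed = [2, MASK, 4, MASK]  # known ranks for genes 101 and 42
result = inpaint(gene_ids, observed, toy_denoiser, n_steps=4, rng=random.Random(0))
```

Because observed positions are never resampled, the same sampler covers both unconditional generation (all positions masked) and inpainting from a partial profile.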

Discussion

We believe that framing gene expression prediction as a sequence-to-sequence problem, and in particular through a generative modeling lens, is an exciting and promising approach that enables controllable gene expression generation.


Dataset used to train jsjung00/DDSC