NASA GeneLab VisionTransformer on BPS Microscopy Data


Authors:

Frank Soboczenski, University of York & King's College London, UK
Lauren Sanders, NASA Ames Research Center
Sylvain Costes, NASA Ames Research Center

General:

This Vision Transformer (ViT) model has been fine-tuned on BPS Microscopy data: fluorescence microscopy images of individual nuclei from mouse fibroblast cells, used to classify DNA damage caused by irradiating the cells with Fe particles or X-rays. We are currently working on an extensive optimisation and evaluation framework. The images used are available here: Biological and Physical Sciences (BPS) Microscopy Benchmark Training Dataset, or as a Hugging Face dataset here: kenobi/GeneLab_BPS_BenchmarkData.

We aim to highlight the ease of use of the Hugging Face platform, its integration with popular deep learning frameworks such as PyTorch, TensorFlow, and JAX, performance monitoring with Weights and Biases, and the ability to easily reuse large pre-trained Transformer models for targeted fine-tuning. To our knowledge, this is the first Vision Transformer model trained on NASA GeneLab data, and we are working on additional versions to address challenges in this domain.

We will include more technical details here soon.

Example Images

Use one of the images below for the inference API field on the upper right.

High_Energy_Ion_Fe_Nuclei

Right-click this link (not the picture shown above) and choose 'Save as'.

XRay_irradiated_Nuclei

Right-click this link (not the picture shown above) and choose 'Save as'.

ViT base training data (currently being replaced)

The base ViT model was pretrained on ImageNet-21k, a dataset of 14 million images spanning 21k classes. More information on the base model can be found here: https://huggingface.co/google/vit-base-patch16-224-in21k
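The base checkpoint's name encodes its input geometry: 224×224 pixel inputs split into 16×16 pixel patches. The sketch below works out the resulting token sequence length and shows what happens to an 8-bit pixel value during preprocessing. The normalization constants (rescale by 1/255, then mean 0.5 and standard deviation 0.5 per channel) are an assumption taken from the base model's processor configuration; check the preprocessor_config.json of kenobi/NASA_GeneLab_MBT before relying on them.

```python
# Input geometry implied by the checkpoint name "vit-base-patch16-224".
IMAGE_SIZE = 224  # pixels per side after resizing
PATCH_SIZE = 16   # pixels per patch side


def num_tokens(image_size: int = IMAGE_SIZE, patch_size: int = PATCH_SIZE) -> int:
    """Number of patch tokens plus one [CLS] token fed to the encoder."""
    if image_size % patch_size != 0:
        raise ValueError("image size must be divisible by patch size")
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side + 1


# Assumed preprocessing constants (from the base model's processor config;
# the fine-tuned model may differ).
MEAN = 0.5
STD = 0.5


def normalize_pixel(value: int) -> float:
    """Map an 8-bit pixel value to the normalized range fed to the model."""
    rescaled = value / 255  # rescale to [0, 1]
    return (rescaled - MEAN) / STD  # shift and scale to [-1, 1]


if __name__ == "__main__":
    print(num_tokens())  # 14 * 14 patches + 1 CLS token = 197
    print(normalize_pixel(0), normalize_pixel(255))  # -1.0 1.0
```

With these constants every pixel lands in [-1, 1], and each image becomes a sequence of 197 tokens of width 768 inside the encoder.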

How to use this Model

(quick snippets that run on Google Colab)

First, a snippet to download test images from an online repository:

import urllib.request

def download_image(url, filename):
    try:
        # Define custom headers
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
        }
        
        # Create a request with custom headers
        req = urllib.request.Request(url, headers=headers)
        
        # Open the URL and read the content
        with urllib.request.urlopen(req) as response:
            img_data = response.read()
            
        # Write the content to a file
        with open(filename, 'wb') as handler:
            handler.write(img_data)
        
        print(f"Image '{filename}' downloaded successfully")
    except Exception as e:
        print(f"Error downloading the image '{filename}':", e)

# List of URLs and corresponding filenames
urls = [
    ('https://roosevelt.devron-systems.com/HF/P242_73665006707-A6_002_008_proj.tif', 'P242_73665006707-A6_002_008_proj.tif'),
    ('https://roosevelt.devron-systems.com/HF/P278_73668090728-A7_003_027_proj.tif', 'P278_73668090728-A7_003_027_proj.tif')
]

# Download each image
for url, filename in urls:
    download_image(url, filename)
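The download helper above saves whatever bytes a successful response contains, so a non-image page served in place of the file would be written under a .tif name and only fail later at Image.open. A minimal sketch of two optional safeguards follows. These are not part of the original snippet, and the helper names filename_from_url and looks_like_tiff are hypothetical. The first derives the local filename from the URL path so the list only needs URLs. The second checks the TIFF magic bytes: "II" or "MM" for the byte order, followed by 42 for classic TIFF or 43 for BigTIFF.

```python
import posixpath
from urllib.parse import urlparse

# A TIFF file starts with a byte-order mark ("II" little-endian, "MM" big-endian)
# followed by the number 42 (classic TIFF) or 43 (BigTIFF) in that byte order.
TIFF_SIGNATURES = (
    b"II*\x00",  # little-endian classic TIFF (42 = 0x2A = '*')
    b"MM\x00*",  # big-endian classic TIFF
    b"II+\x00",  # little-endian BigTIFF (43 = 0x2B = '+')
    b"MM\x00+",  # big-endian BigTIFF
)


def filename_from_url(url: str) -> str:
    """Use the last component of the URL path as the local filename."""
    name = posixpath.basename(urlparse(url).path)
    if not name:
        raise ValueError(f"URL has no filename component: {url}")
    return name


def looks_like_tiff(path: str) -> bool:
    """Return True if the file begins with a TIFF header."""
    with open(path, "rb") as handle:
        header = handle.read(4)
    return header in TIFF_SIGNATURES
```

In the download loop, filename_from_url(url) can replace the hand-written filenames, and looks_like_tiff(filename) can be called after download_image to flag files that are not actually TIFF images.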

Then use the images for inference:

#!pip install transformers --quiet # uncomment this pip install for local use if you do not have transformers installed
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image

# Load the image
#image = Image.open('P242_73665006707-A6_002_008_proj.tif') #First Image
image = Image.open('P278_73668090728-A7_003_027_proj.tif')  #Second Image

# Convert the grayscale image to the 3-channel RGB input the ViT expects
image_rgb = image.convert("RGB")

# Load the pre-trained image processor and classification model
image_processor = AutoImageProcessor.from_pretrained("kenobi/NASA_GeneLab_MBT")
model = AutoModelForImageClassification.from_pretrained("kenobi/NASA_GeneLab_MBT")

# Resize and normalize the image into model-ready tensors
inputs = image_processor(images=image_rgb, return_tensors="pt")

# Perform classification (gradients are not needed for inference)
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits

# Obtain the predicted class index and label
predicted_class_idx = logits.argmax(-1).item()
predicted_class_label = model.config.id2label[predicted_class_idx]

print("Predicted class:", predicted_class_label)
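The snippet reports only the highest-scoring label. To see how confident the classifier is, the logits can be turned into probabilities with a softmax. The pure-Python sketch below shows that arithmetic on a plain list. The logit values and the two-entry id2label mapping are hypothetical, with label names borrowed from the example-image names rather than read from the model's config. With the real model, torch.softmax(outputs.logits, dim=-1) performs the same computation.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    largest = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def rank_predictions(logits, id2label):
    """Return (label, probability) pairs sorted from most to least likely."""
    probs = softmax(logits)
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(id2label[i], probs[i]) for i in order]


# Hypothetical values for illustration only; the real mapping lives in
# model.config.id2label and the real scores in outputs.logits.
example_id2label = {0: "High_Energy_Ion_Fe_Nuclei", 1: "XRay_irradiated_Nuclei"}
example_logits = [0.3, 2.1]

for label, prob in rank_predictions(example_logits, example_id2label):
    print(f"{label}: {prob:.3f}")
```

Softmax preserves the ordering of the logits, so the top-ranked label always matches the argmax prediction; the probabilities only add a confidence estimate.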

BibTeX & References

A publication on this work is currently in preparation. In the meantime, please refer to this model by using the following citation:

For the base ViT model, please refer to:

@misc{dosovitskiy2020image,
      title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale}, 
      author={Alexey Dosovitskiy and Lucas Beyer and Alexander Kolesnikov and Dirk Weissenborn and Xiaohua Zhai and Thomas Unterthiner and Mostafa Dehghani and Matthias Minderer and Georg Heigold and Sylvain Gelly and Jakob Uszkoreit and Neil Houlsby},
      year={2020},
      eprint={2010.11929},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

For ImageNet, please refer to:

@inproceedings{deng2009imagenet,
  title={Imagenet: A large-scale hierarchical image database},
  author={Deng, Jia and Dong, Wei and Socher, Richard and Li, Li-Jia and Li, Kai and Fei-Fei, Li},
  booktitle={2009 IEEE conference on computer vision and pattern recognition},
  pages={248--255},
  year={2009},
  organization={IEEE}
}

Model size: 85.8M params (Safetensors, tensor type F32)
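The 85.8M figure can be reproduced by hand from the standard ViT-Base hyperparameters (12 layers, hidden size 768, MLP size 3072, 16×16 patches, 224×224 input). The sketch below assumes the layer layout of the Transformers ViTForImageClassification implementation (biased query/key/value projections, no pooler, a single linear classification head). It also assumes, hypothetically, two output classes; each additional class adds 769 parameters (768 weights plus one bias).

```python
def vit_param_count(
    num_labels: int,
    image_size: int = 224,
    patch_size: int = 16,
    channels: int = 3,
    hidden: int = 768,
    mlp: int = 3072,
    layers: int = 12,
) -> int:
    """Count parameters of a ViT image classifier (no pooler, linear head)."""

    def linear(n_in: int, n_out: int) -> int:
        return n_in * n_out + n_out  # weight matrix + bias vector

    layer_norm = 2 * hidden  # scale + shift
    num_patches = (image_size // patch_size) ** 2

    embeddings = (
        channels * patch_size * patch_size * hidden + hidden  # patch projection
        + hidden                                              # [CLS] token
        + (num_patches + 1) * hidden                          # position embeddings
    )
    per_layer = (
        3 * linear(hidden, hidden)  # query, key, value projections
        + linear(hidden, hidden)    # attention output projection
        + linear(hidden, mlp)       # MLP expansion
        + linear(mlp, hidden)       # MLP contraction
        + 2 * layer_norm            # LayerNorms before attention and MLP
    )
    head = linear(hidden, num_labels)
    # embeddings + encoder stack + final LayerNorm + classification head
    return embeddings + layers * per_layer + layer_norm + head


print(f"{vit_param_count(num_labels=2) / 1e6:.1f}M parameters")  # 85.8M parameters
```

Almost all of the parameters sit in the 12 encoder layers (about 7.09M each); the classification head contributes only about 1.5k, which is why the count barely depends on the number of classes.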