LitGene: An Interpretable Transformer Model for Gene Representation Learning
LitGene is a transformer-based model that learns rich gene representations by integrating textual information from the scientific literature with structured knowledge from the Gene Ontology (GO). Using contrastive learning, the model refines gene embeddings that capture both sequence and functional annotations, enabling improved prediction of protein properties, gene-disease associations, and functional annotations such as GO terms and KEGG pathways.
This repository provides model weights for the pre-trained LitGene model. It is intended to serve as a base representation model that can be further adapted/fine-tuned for specific biomedical tasks.
Intended Usage
This model is intended for tasks that need learned gene representations. LitGene can be used for any of the following:
- Inference: Providing predictions for gene functions, gene-disease and gene-protein associations, and specific biological pathway information. Prompt LitGene here.
- Gene Embeddings: Producing embeddings that capture both textual (literature-based) information and specific biological properties of gene function. See https://github.com/vinash85/LitGene/tree/master.
- Fine-tuning: The base representation model can be fine-tuned for a wide range of biomedical tasks (e.g., protein solubility prediction, drug dosage sensitivity). Example tasks can be found in this repo.
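Once gene embeddings are extracted, one common way to compare genes is cosine similarity. The sketch below ranks candidate genes against a query gene in plain Python. The gene names and 3-dimensional vectors are hypothetical placeholders for real 768-dimensional LitGene embeddings (for example, a pooled tensor converted with `.tolist()`), and cosine similarity is an illustrative choice, not a method prescribed by this card.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def rank_by_similarity(query, candidates):
    """Return (name, similarity) pairs sorted from most to least similar."""
    scores = [(name, cosine_similarity(query, vec)) for name, vec in candidates.items()]
    return sorted(scores, key=lambda pair: pair[1], reverse=True)


# Hypothetical toy vectors standing in for LitGene gene embeddings.
query = [1.0, 0.0, 0.0]
candidates = {
    "GENE_A": [0.9, 0.1, 0.0],
    "GENE_B": [0.0, 1.0, 0.0],
    "GENE_C": [0.5, 0.5, 0.0],
}

for name, score in rank_by_similarity(query, candidates):
    print(f"{name}: {score:.3f}")
```

For large gene sets, the same computation is usually done as a single matrix product over L2-normalized embeddings rather than a Python loop.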
Usage (PyTorch)
The following PyTorch code loads the LitGene weights and tokenizer from the Hugging Face Hub:
```python
import torch
from transformers import AutoModel, AutoTokenizer

# Load the model and tokenizer
model_name = "tumorailab/LitGene_ContrastiveLearning"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Move the model to the GPU if one is available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
```
The following code computes embeddings for an example sentence:
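```python
# Prepare your sentence
sentence = "Your text goes here"

# Tokenize the sentence
inputs = tokenizer(
    sentence,
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt",
)

# Move inputs to the same device as the model
inputs = {k: v.to(device) for k, v in inputs.items()}

# Get embeddings (eval mode disables dropout; no_grad skips gradient tracking)
model.eval()
with torch.no_grad():
    outputs = model(**inputs)

# Token-level embeddings: (batch_size, sequence_length, hidden_size)
token_embeddings = outputs.last_hidden_state

# CLS token embedding (first token): (batch_size, hidden_size)
cls_embedding = token_embeddings[:, 0, :]
print(cls_embedding.shape)
```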
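Taking the CLS token is one way to reduce the token-level output to a single vector. A common alternative is masked mean pooling, which averages the token embeddings while ignoring padding positions. This card does not state which pooling LitGene used during training, so treat this as an option to evaluate rather than the authors' method; `mean_pool` is a hypothetical helper. With the example above, you would call `mean_pool(outputs.last_hidden_state, inputs["attention_mask"])`.

```python
import torch


def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings, ignoring padding positions.

    last_hidden_state: (batch, seq_len, hidden)
    attention_mask:    (batch, seq_len), 1 for real tokens, 0 for padding
    """
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # (batch, seq_len, 1)
    summed = (last_hidden_state * mask).sum(dim=1)                    # (batch, hidden)
    counts = mask.sum(dim=1).clamp(min=1e-9)                          # avoid division by zero
    return summed / counts


# Toy check: 1 sequence, 3 positions (the last one is padding), hidden size 2.
hidden = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
mask = torch.tensor([[1, 1, 0]])
print(mean_pool(hidden, mask))  # tensor([[2., 3.]])
```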
Training Details
Hyperparameters
Hyperparameter | Value |
---|---|
Embedding Dimension | 768 |
Batch Size | 64 |
Optimizer | AdamW |
Learning Rate | 2e-5 (with linear decay) |
Weight Decay | 0.01 |
Contrastive Learning Loss Function | Margin-based ranking loss |
Contrastive Loss Margin (δ) | 0.5 |
Number of Training Steps | 100k |
Dropout Rate | 0.1 |
Gradient Clipping | 1.0 |
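The hyperparameters above can be combined into a single PyTorch training step. This is a minimal sketch under several assumptions the card does not state: a tiny stand-in encoder (`nn.Linear` plus dropout) replaces the LitGene transformer, each training example is an (anchor, positive, negative) triple compared by cosine similarity, the learning rate decays linearly to zero over the 100k steps with no warmup, and "Gradient Clipping 1.0" is read as clipping the global gradient norm. `contrastive_ranking_loss` and `train_step` are hypothetical helper names. The loss is `max(0, δ - (sim(anchor, positive) - sim(anchor, negative)))` with δ = 0.5.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in encoder: maps 32-d toy features to 768-d embeddings.
# In practice this would be the LitGene transformer.
encoder = nn.Sequential(nn.Linear(32, 768), nn.Dropout(0.1))

TOTAL_STEPS = 100_000
optimizer = torch.optim.AdamW(encoder.parameters(), lr=2e-5, weight_decay=0.01)
# Linear decay from 2e-5 to 0 over TOTAL_STEPS (assumption: no warmup).
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: max(0.0, 1.0 - step / TOTAL_STEPS)
)


def contrastive_ranking_loss(anchor, positive, negative, margin=0.5):
    """Batch mean of max(0, margin - (sim(a, p) - sim(a, n)))."""
    sim_pos = F.cosine_similarity(anchor, positive, dim=-1)
    sim_neg = F.cosine_similarity(anchor, negative, dim=-1)
    # target = 1 means sim_pos should be ranked above sim_neg
    target = torch.ones_like(sim_pos)
    return F.margin_ranking_loss(sim_pos, sim_neg, target, margin=margin)


def train_step(anchor_x, pos_x, neg_x):
    encoder.train()
    optimizer.zero_grad()
    loss = contrastive_ranking_loss(encoder(anchor_x), encoder(pos_x), encoder(neg_x))
    loss.backward()
    # Assumption: clip the global gradient norm to 1.0
    torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()
    return loss.item()


batch_size = 64
loss = train_step(
    torch.randn(batch_size, 32), torch.randn(batch_size, 32), torch.randn(batch_size, 32)
)
print(f"loss={loss:.4f}, lr={scheduler.get_last_lr()[0]:.3e}")
```

Because the loss is zero once the positive is more similar than the negative by at least δ, easy triples stop contributing gradients, and training focuses on pairs that are still ranked incorrectly or too close together.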