Flash Attention ESM2 (FAESM)
This is an efficient FlashAttention implementation of ESM2 (Evolutionary Scale Modeling) that provides nearly 50% faster inference and lower memory use compared to the original implementation. All of the source code comes from FAPLM. Give us a star if you find it useful :)
Key Features
- Automatic FlashAttention: uses FlashAttention kernels when available, for up to 70% faster inference and 60% lower memory use
- Smart Fallback: automatically falls back to PyTorch SDPA if Flash Attention is not installed (see the sketch after this list)
- Drop-in Replacement: same API as the original ESM2 models
- Memory Efficient: removes padding tokens during computation for better efficiency
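The dispatch behind the first two features follows a common pattern, sketched below. This is an illustrative sketch under stated assumptions, not FAESM's actual source: it assumes flash-attn 2's `flash_attn_func`, PyTorch >= 2.0's `scaled_dot_product_attention`, and the helper name `attend` is hypothetical.

```python
import torch
import torch.nn.functional as F

# Probe for flash-attn once at import time (assumption: flash-attn >= 2.x).
try:
    from flash_attn import flash_attn_func
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False

def attend(q, k, v):
    """q, k, v: (batch, n_heads, seq_len, head_dim). Hypothetical helper."""
    if HAS_FLASH_ATTN and q.is_cuda and q.dtype in (torch.float16, torch.bfloat16):
        # flash_attn_func expects (batch, seq_len, n_heads, head_dim)
        out = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
        return out.transpose(1, 2)
    # Fallback: PyTorch's built-in scaled dot-product attention (SDPA)
    return F.scaled_dot_product_attention(q, k, v)
```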
Installation Requirements
```bash
# Install PyTorch (if not already installed)
pip install torch

# Install Flash Attention (optional, for best performance)
pip install flash-attn --no-build-isolation --no-cache-dir

# Install Hugging Face Transformers
pip install transformers
```
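To confirm which attention path you will get, a quick check like the following (an illustrative snippet, not part of FAESM) can help:

```python
import torch

print("PyTorch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
try:
    import flash_attn
    print("flash-attn:", flash_attn.__version__, "-> FlashAttention path")
except ImportError:
    print("flash-attn not installed -> PyTorch SDPA fallback")
```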
Usage
The only change from standard ESM2 usage is the repo name and turning on `trust_remote_code=True`:
```python
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Half precision (.half()) is required for the FlashAttention path;
# the SDPA fallback also works in fp32 if you drop it.
model = (
    AutoModelForMaskedLM.from_pretrained(
        "fredzzp/esm2_t33_650M_UR50D", trust_remote_code=True
    )
    .to("cuda")
    .eval()
    .half()
)
tokenizer = AutoTokenizer.from_pretrained("fredzzp/esm2_t33_650M_UR50D")

input_ids = tokenizer("AGC", return_tensors="pt").input_ids.to("cuda")
output = model(input_ids)

print(output["logits"].shape)
print(output["last_hidden_state"].shape)
```
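Since the head is a masked LM, a natural next step is scoring a masked position. A hedged example building on the snippet above (the toy sequence and top-5 choice are arbitrary; ESM tokenizers use `<mask>` as the mask token):

```python
import torch

# Mask the third residue of a toy sequence
inputs = tokenizer("AG<mask>C", return_tensors="pt").to("cuda")
with torch.no_grad():
    logits = model(inputs.input_ids)["logits"]

# Find the masked position and print the 5 most likely residues
mask_pos = (inputs.input_ids == tokenizer.mask_token_id).nonzero()[0, 1]
top5 = logits[0, mask_pos].topk(5).indices
print([tokenizer.decode([i]) for i in top5])
```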
Supported ESM Versions
| Model | Num Layers | Num Parameters |
|---|---|---|
| fredzzp/esm2_t36_3B_UR50D | 36 | 3B |
| fredzzp/esm2_t33_650M_UR50D | 33 | 650M |
| fredzzp/esm2_t30_150M_UR50D | 30 | 150M |
| fredzzp/esm2_t12_35M_UR50D | 12 | 35M |
| fredzzp/esm2_t6_8M_UR50D | 6 | 8M |
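Any checkpoint above is a drop-in swap: only the repo name changes. For example, the 8M model is convenient for quick experiments:

```python
# Same API as the 650M example above; only the repo name differs.
model = AutoModelForMaskedLM.from_pretrained(
    "fredzzp/esm2_t6_8M_UR50D", trust_remote_code=True
).to("cuda").eval().half()
tokenizer = AutoTokenizer.from_pretrained("fredzzp/esm2_t6_8M_UR50D")
```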
Citation
If you use this implementation, please cite both the original ESM2 paper and this work:
```bibtex
@article{lin2023evolutionary,
  title={Evolutionary-scale prediction of atomic-level protein structure with a language model},
  author={Lin, Zeming and Akin, Halil and Rao, Roshan and Hie, Brian and Zhu, Zihang and Lu, Wenting and Smetanin, Nikita and Verkuil, Robert and Kabeli, Ori and Shmueli, Yair and others},
  journal={Science},
  volume={379},
  number={6637},
  pages={1123--1130},
  year={2023},
  publisher={American Association for the Advancement of Science}
}
```

```bibtex
@misc{faesm2024,
  author = {Fred Zhangzhi Peng and Pranam Chatterjee and contributors},
  title = {FAESM: An efficient PyTorch implementation of Evolutionary Scale Modeling (ESM)},
  year = {2024},
  howpublished = {\url{https://github.com/pengzhangzhi/faesm}},
  note = {Efficient PyTorch implementation of ESM with FlashAttention and Scaled Dot-Product Attention (SDPA)},
  abstract = {FAESM is a drop-in replacement for the official ESM implementation, designed to save up to 60% memory usage and 70% inference time, while maintaining compatibility with the ESM API.},
}
```
License
This implementation is licensed under the MIT License. The ESM2 model weights maintain their original licensing terms.