Flash Attention ESM2 (FAESM)

This is an efficient Flash Attention implementation of ESM2 (Evolutionary Scale Modeling) that provides a speedup and memory reduction of nearly 50% compared with the original implementation. All of the source code comes from FAPLM. Give us a star if you find it useful :)

Key Features

  • Automatic Flash Attention: Uses FlashAttention automatically when it is installed, for up to 70% faster inference and up to 60% less memory
  • Smart Fallback: Falls back to PyTorch SDPA (scaled dot-product attention) if Flash Attention is not installed
  • Drop-in Replacement: Same API as the original ESM2 models
  • Memory Efficient: Removes padding tokens during computation for better efficiency
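Padding removal usually means packing the real tokens of a batch into one flat sequence plus cumulative sequence offsets, the "varlen" layout that flash-attn's variable-length kernels consume. The sketch below illustrates that idea in plain PyTorch. It assumes right-padded batches with a 0/1 attention mask; `unpad` and `repad` are hypothetical helpers for illustration, not FAESM's actual code.

```python
import torch
import torch.nn.functional as F


def unpad(hidden: torch.Tensor, attention_mask: torch.Tensor):
    """Pack non-padding tokens into a flat tensor (hypothetical helper).

    hidden:         (batch, seq_len, dim) token representations
    attention_mask: (batch, seq_len), 1 for real tokens, 0 for padding
    Returns (packed, indices, cu_seqlens, max_seqlen).
    """
    # Number of real tokens in each sequence.
    seqlens = attention_mask.sum(dim=1, dtype=torch.int32)
    # Positions of real tokens in the flattened (batch * seq_len) view.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # Cumulative offsets [0, len_0, len_0 + len_1, ...] marking sequence boundaries.
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    # Keep only the real tokens; padding never reaches the attention kernel.
    packed = hidden.reshape(-1, hidden.shape[-1])[indices]
    return packed, indices, cu_seqlens, int(seqlens.max())


def repad(packed: torch.Tensor, indices: torch.Tensor, batch: int, seq_len: int):
    """Scatter packed tokens back into a zero-padded (batch, seq_len, dim) tensor."""
    out = packed.new_zeros(batch * seq_len, packed.shape[-1])
    out[indices] = packed
    return out.view(batch, seq_len, -1)
```

Only the packed tokens go through attention, so compute scales with the number of real residues rather than with `batch * seq_len`.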

Installation Requirements

# Install PyTorch (if not already installed)
pip install torch

# Install Flash Attention (optional, for best performance)
pip install flash-attn --no-build-isolation --no-cache-dir

# Install Hugging Face Transformers
pip install transformers
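Because flash-attn is optional, it can help to check which attention path is likely to run before loading a model. This is a minimal sketch that assumes the fallback decision depends only on whether the `flash_attn` package is importable and a CUDA device is present. The real check lives in the model's remote code and may differ, and `expected_backend` is a hypothetical helper.

```python
import importlib.util


def flash_attn_installed() -> bool:
    """True if the optional flash-attn package can be imported."""
    return importlib.util.find_spec("flash_attn") is not None


def expected_backend() -> str:
    """Guess which attention path will run (assumption: mirrors FAESM's check)."""
    if importlib.util.find_spec("torch") is None:
        return "torch missing: install PyTorch first"
    import torch

    # FlashAttention kernels need a CUDA GPU; otherwise expect the SDPA fallback.
    if flash_attn_installed() and torch.cuda.is_available():
        return "flash_attn"
    return "sdpa"


if __name__ == "__main__":
    print("flash-attn installed:", flash_attn_installed())
    print("expected backend:", expected_backend())
```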

Usage

Only two changes are needed: point to the FAESM repo name and set trust_remote_code=True.

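import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Load the FAESM weights; trust_remote_code=True runs the repo's Flash Attention model code.
# .half() casts to fp16, since FlashAttention kernels only support fp16/bf16.
model = AutoModelForMaskedLM.from_pretrained(
    "fredzzp/esm2_t33_650M_UR50D", trust_remote_code=True
).to("cuda").eval().half()
tokenizer = AutoTokenizer.from_pretrained("fredzzp/esm2_t33_650M_UR50D")

# Tokenize a protein sequence and run a forward pass without tracking gradients.
input_ids = tokenizer("AGC", return_tensors="pt").input_ids.to("cuda")
with torch.no_grad():
    output = model(input_ids)
print(output['logits'].shape)             # (batch, seq_len, vocab_size)
print(output['last_hidden_state'].shape)  # (batch, seq_len, hidden_dim)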
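For downstream tasks, the per-token `last_hidden_state` is often reduced to one vector per protein. The sketch below makes three assumptions: the tokenizer wraps each sequence in one start token (`<cls>`) and one end token (`<eos>`), as the original ESM2 tokenizer does; batches are right-padded; and padding positions are 0 in `attention_mask`. `mean_pool_residues` is a hypothetical helper, not part of FAESM.

```python
import torch


def mean_pool_residues(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average per-residue embeddings, excluding <cls>, <eos> and padding.

    last_hidden_state: (batch, seq_len, dim)
    attention_mask:    (batch, seq_len), 1 for real tokens (incl. special tokens)
    Returns:           (batch, dim)
    """
    mask = attention_mask.clone().to(last_hidden_state.dtype)
    mask[:, 0] = 0  # drop the <cls> token at position 0
    lengths = attention_mask.sum(dim=1).long()  # real length including <cls>/<eos>
    mask[torch.arange(mask.shape[0]), lengths - 1] = 0  # drop each row's <eos>
    summed = (last_hidden_state * mask.unsqueeze(-1)).sum(dim=1)
    # Divide by the residue count; clamp avoids division by zero on empty sequences.
    return summed / mask.sum(dim=1, keepdim=True).clamp(min=1)
```

Excluding the special tokens keeps the average limited to amino-acid positions. Masking by position instead of token ID keeps the helper independent of the vocabulary.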

Supported ESM Versions

Citation

If you use this implementation, please cite both the original ESM2 paper and this work:

@article{lin2023evolutionary,
  title={Evolutionary-scale prediction of atomic-level protein structure with a language model},
  author={Lin, Zeming and Akin, Halil and Rao, Roshan and Hie, Brian and Zhu, Zihang and Lu, Wenting and Smetanin, Nikita and Verkuil, Robert and Kabeli, Ori and Shmueli, Yair and others},
  journal={Science},
  volume={379},
  number={6637},
  pages={1123--1130},
  year={2023},
  publisher={American Association for the Advancement of Science}
}

@misc{faesm2024,
  author       = {Fred Zhangzhi Peng and Pranam Chatterjee and others},
  title        = {FAESM: An efficient PyTorch implementation of Evolutionary Scale Modeling (ESM)},
  year         = {2024},
  howpublished = {\url{https://github.com/pengzhangzhi/faesm}},
  note         = {Efficient PyTorch implementation of ESM with FlashAttention and Scaled Dot-Product Attention (SDPA)},
  abstract     = {FAESM is a drop-in replacement for the official ESM implementation, designed to save up to 60% memory usage and 70% inference time, while maintaining compatibility with the ESM API.},
}

License

This implementation is licensed under the MIT License. The ESM2 model weights maintain their original licensing terms.
