Flash Attention ESM2 (FAESM)

This is an efficient Flash Attention implementation of ESM2 (Evolutionary Scale Modeling) that provides nearly a 50% speedup and a comparable memory reduction over the original implementation. All of the source code comes from FAPLM. Give us a star if you find it useful :)

Key Features

  • Automatic Flash Attention: Uses FlashAttention when it is available, for up to 70% faster inference and 60% lower memory use
  • Smart Fallback: Falls back to PyTorch SDPA automatically if Flash Attention is not installed (a minimal sketch of this pattern follows this list)
  • Drop-in Replacement: Same API as the original ESM2 models
  • Memory Efficient: Removes padding tokens during computation for better efficiency, also illustrated after this list
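
The fallback pattern can be pictured with a short sketch. This is illustrative only, not FAESM's actual code: it assumes the flash_attn package's flash_attn_func (which takes (batch, seq_len, n_heads, head_dim) tensors) and PyTorch's built-in scaled_dot_product_attention.

import torch
import torch.nn.functional as F

try:
    from flash_attn import flash_attn_func  # fused FlashAttention kernel
    HAS_FLASH = True
except ImportError:
    HAS_FLASH = False

def attend(q, k, v):
    # q, k, v: (batch, seq_len, n_heads, head_dim)
    # FlashAttention needs fp16/bf16 tensors on CUDA; otherwise fall back to SDPA.
    if HAS_FLASH and q.is_cuda and q.dtype in (torch.float16, torch.bfloat16):
        return flash_attn_func(q, k, v)
    # SDPA expects (batch, n_heads, seq_len, head_dim), so transpose in and out.
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    )
    return out.transpose(1, 2)

The padding-removal idea can likewise be sketched in plain PyTorch (hypothetical shapes; variable-length attention kernels consume the flattened tokens and the cu_seqlens offsets directly):

import torch

hidden = torch.randn(2, 5, 8)                        # (batch, seq, dim), padded to length 5
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]], dtype=torch.bool)
indices = mask.flatten().nonzero(as_tuple=True)[0]   # positions of real tokens
unpadded = hidden.flatten(0, 1)[indices]             # (total_tokens, dim): no compute on pads
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                        mask.sum(1).cumsum(0).to(torch.int32)])  # per-sequence offsets [0, 3, 8]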

Installation Requirements

# Install PyTorch (if not already installed)
pip install torch

# Install Flash Attention (optional, for best performance)
pip install flash-attn --no-build-isolation --no-cache-dir

# Install Hugging Face Transformers
pip install transformers
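
A quick optional check to confirm which attention path will be used:

try:
    import flash_attn
    print("flash-attn", flash_attn.__version__, "- FlashAttention kernels available")
except ImportError:
    print("flash-attn not installed - will fall back to PyTorch SDPA")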

Usage

The only change is the repo name, plus setting trust_remote_code=True.

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

# trust_remote_code=True pulls in the FAESM implementation; load in fp16 on GPU for the fast path.
model = AutoModelForMaskedLM.from_pretrained("fredzzp/esm2_t33_650M_UR50D", trust_remote_code=True).to("cuda").eval().half()
tokenizer = AutoTokenizer.from_pretrained("fredzzp/esm2_t33_650M_UR50D")

input_ids = tokenizer("AGC", return_tensors="pt").input_ids.to("cuda")
with torch.no_grad():
    output = model(input_ids)
print(output['logits'].shape)             # (batch, seq_len, vocab_size)
print(output['last_hidden_state'].shape)  # (batch, seq_len, hidden_dim)
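
Because the model is loaded for masked language modeling, a natural next step is single-residue prediction. A minimal sketch reusing the model, tokenizer, and torch import from above (the sequence and masked position are hypothetical, and the output key follows the snippet above):

input_ids = tokenizer("MKTAYIAKQR", return_tensors="pt").input_ids.to("cuda")
input_ids[0, 4] = tokenizer.mask_token_id   # mask one residue (index 4 counts the BOS token)
with torch.no_grad():
    logits = model(input_ids)['logits']
pred_id = logits[0, 4].argmax(-1).item()
print("predicted residue:", tokenizer.convert_ids_to_tokens(pred_id))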

Supported ESM Versions

Citation

If you use this implementation, please cite both the original ESM2 paper and this work:

@article{lin2023evolutionary,
  title={Evolutionary-scale prediction of atomic-level protein structure with a language model},
  author={Lin, Zeming and Akin, Halil and Rao, Roshan and Hie, Brian and Zhu, Zhongkai and Lu, Wenting and Smetanin, Nikita and Verkuil, Robert and Kabeli, Ori and Shmueli, Yaniv and others},
  journal={Science},
  volume={379},
  number={6637},
  pages={1123--1130},
  year={2023},
  publisher={American Association for the Advancement of Science}
}

@misc{faesm2024,
  author       = {Fred Zhangzhi Peng and Pranam Chatterjee and contributors},
  title        = {FAESM: An efficient PyTorch implementation of Evolutionary Scale Modeling (ESM)},
  year         = {2024},
  howpublished = {\url{https://github.com/pengzhangzhi/faesm}},
  note         = {Efficient PyTorch implementation of ESM with FlashAttention and Scaled Dot-Product Attention (SDPA)},
  abstract     = {FAESM is a drop-in replacement for the official ESM implementation, designed to save up to 60% memory usage and 70% inference time, while maintaining compatibility with the ESM API.},
}

License

This implementation is licensed under the MIT License. The ESM2 model weights maintain their original licensing terms.
