Update README.md
Browse files
README.md
CHANGED
@@ -1,3 +1,275 @@
|
|
1 |
-
---
|
2 |
-
license: mit
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: mit
|
3 |
+
datasets:
|
4 |
+
- colabfit/MD22_buckyball_catcher
|
5 |
+
- colabfit/MD22_AT_AT
|
6 |
+
- colabfit/MD22_stachyose
|
7 |
+
- colabfit/MD22_AT_AT_CG_CG
|
8 |
+
- colabfit/MD22_Ac_Ala3_NHMe
|
9 |
+
- colabfit/MD22_DHA
|
10 |
+
- colabfit/MD22_double_walled_nanotube
|
11 |
+
- yairschiff/qm9
|
12 |
+
- maomlab/Molecule3D
|
13 |
+
metrics:
|
14 |
+
- mae
|
15 |
+
tags:
|
16 |
+
- equivariant
|
17 |
+
- graph neural network
|
18 |
+
- molecular property prediction
|
19 |
+
---
|
20 |
+
|
21 |
+
# GotenNet: Rethinking Efficient 3D Equivariant Graph Neural Networks
|
22 |
+
|
23 |
+
<div align="center">
|
24 |
+
|
25 |
+
[](https://openreview.net/pdf?id=5wxCQDtbMo)
|
26 |
+
[](https://www.sarpaykent.com/publications/gotennet/)
|
27 |
+
[](LICENSE)
|
28 |
+
[](https://pypi.org/project/gotennet/)
|
29 |
+
[](https://pytorch.org/)
|
30 |
+
|
31 |
+
</div>
|
32 |
+
|
33 |
+
<p align="center">
|
34 |
+
<img src="https://raw.githubusercontent.com/sarpaykent/GotenNet/refs/heads/main/assets/GotenNet_framework.png" width="800">
|
35 |
+
</p>
|
36 |
+
|
37 |
+
## Overview
|
38 |
+
|
39 |
+
This is the official implementation of **"GotenNet: Rethinking Efficient 3D Equivariant Graph Neural Networks"** published at ICLR 2025.
|
40 |
+
|
41 |
+
GotenNet introduces a novel framework for modeling 3D molecular structures that achieves state-of-the-art performance while maintaining computational efficiency. Our approach balances expressiveness and efficiency through innovative tensor-based representations and attention mechanisms.
|
42 |
+
|
43 |
+
## Table of Contents
|
44 |
+
- [✨ Key Features](#-key-features)
|
45 |
+
- [🚀 Installation](#-installation)
|
46 |
+
- [📦 From PyPI (Recommended)](#-from-pypi-recommended)
|
47 |
+
- [🔧 From Source](#🔧-from-source)
|
48 |
+
- [🔬 Usage](#🔬-usage)
|
49 |
+
- [Using the Model](#using-the-model)
|
50 |
+
- [Loading Pre-trained Models Programmatically](#loading-pre-trained-models-programmatically)
|
51 |
+
- [Training a Model](#training-a-model)
|
52 |
+
- [Testing a Model](#testing-a-model)
|
53 |
+
- [Configuration](#configuration)
|
54 |
+
- [🤝 Contributing](#-contributing)
|
55 |
+
- [📚 Citation](#-citation)
|
56 |
+
- [📄 License](#-license)
|
57 |
+
- [Acknowledgements](#acknowledgements)
|
58 |
+
|
59 |
+
## ✨ Key Features
|
60 |
+
|
61 |
+
- 🔄 **Effective Geometric Tensor Representations**: Leverages geometric tensors without relying on irreducible representations or Clebsch-Gordan transforms
|
62 |
+
- 🧩 **Unified Structural Embedding**: Introduces geometry-aware tensor attention for improved molecular representation
|
63 |
+
- 📊 **Hierarchical Tensor Refinement**: Implements a flexible and efficient representation scheme
|
64 |
+
- 🏆 **State-of-the-Art Performance**: Achieves superior results on QM9, rMD17, MD22, and Molecule3D datasets
|
65 |
+
- 📈 **Load Pre-trained Models**: Easily load and use pre-trained model checkpoints by name, URL, or local path, with automatic download capabilities.
|
66 |
+
|
67 |
+
## 🚀 Installation
|
68 |
+
|
69 |
+
### 📦 From PyPI (Recommended)
|
70 |
+
|
71 |
+
You can install it using pip:
|
72 |
+
|
73 |
+
* **Core Model Only:** Installs only the essential dependencies required to use the `GotenNet` model.
|
74 |
+
```bash
|
75 |
+
pip install gotennet
|
76 |
+
```
|
77 |
+
|
78 |
+
* **Full Installation (Core + Training/Utilities):** Installs core dependencies plus libraries needed for training, data handling, logging, etc.
|
79 |
+
```bash
|
80 |
+
pip install gotennet[full]
|
81 |
+
```
|
82 |
+
|
83 |
+
### 🔧 From Source
|
84 |
+
|
85 |
+
1. **Clone the repository:**
|
86 |
+
```bash
|
87 |
+
git clone https://github.com/sarpaykent/gotennet.git
|
88 |
+
cd gotennet
|
89 |
+
```
|
90 |
+
|
91 |
+
2. **Create and activate a virtual environment** (using conda or venv/uv):
|
92 |
+
```bash
|
93 |
+
# Using conda
|
94 |
+
conda create -n gotennet python=3.10
|
95 |
+
conda activate gotennet
|
96 |
+
|
97 |
+
# Or using venv/uv
|
98 |
+
uv venv --python 3.10
|
99 |
+
source .venv/bin/activate
|
100 |
+
```
|
101 |
+
|
102 |
+
3. **Install the package:**
|
103 |
+
Choose the installation type based on your needs:
|
104 |
+
|
105 |
+
* **Core Model Only:** Installs only the essential dependencies required to use the `GotenNet` model.
|
106 |
+
```bash
|
107 |
+
pip install .
|
108 |
+
```
|
109 |
+
|
110 |
+
* **Full Installation (Core + Training/Utilities):** Installs core dependencies plus libraries needed for training, data handling, logging, etc.
|
111 |
+
```bash
|
112 |
+
pip install .[full]
|
113 |
+
# Or for editable install:
|
114 |
+
# pip install -e .[full]
|
115 |
+
```
|
116 |
+
*(Note: `uv` can be used as a faster alternative to `pip` for installation, e.g., `uv pip install .[full]`)*
|
117 |
+
|
118 |
+
## 🔬 Usage
|
119 |
+
|
120 |
+
### Using the Model
|
121 |
+
|
122 |
+
Once installed, you can import and use the `GotenNet` model directly in your Python code:
|
123 |
+
|
124 |
+
```python
|
125 |
+
from gotennet import GotenNet
|
126 |
+
|
127 |
+
# --- Using the base GotenNet model ---
|
128 |
+
# Requires manual calculation of edge_index, edge_diff, edge_vec
|
129 |
+
|
130 |
+
# Example instantiation
|
131 |
+
model = GotenNet(
|
132 |
+
n_atom_basis=256,
|
133 |
+
n_interactions=4,
|
134 |
+
# resf of the parameters
|
135 |
+
)
|
136 |
+
|
137 |
+
# Encoded representations can be computed with
|
138 |
+
h, X = model(atomic_numbers, edge_index, edge_diff, edge_vec)
|
139 |
+
|
140 |
+
# --- Using GotenNetWrapper (handles distance calculation) ---
|
141 |
+
# Expects a PyTorch Geometric Data object or similar dict
|
142 |
+
# with keys like 'z' (atomic_numbers), 'pos' (positions), 'batch'
|
143 |
+
|
144 |
+
# Example instantiation
|
145 |
+
from gotennet import GotenNetWrapper
|
146 |
+
wrapped_model = GotenNetWrapper(
|
147 |
+
n_atom_basis=256,
|
148 |
+
n_interactions=4,
|
149 |
+
# rest of the parameters
|
150 |
+
)
|
151 |
+
|
152 |
+
# Encoded representations can be computed with
|
153 |
+
h, X = wrapped_model(data)
|
154 |
+
|
155 |
+
```
|
156 |
+
|
157 |
+
### Loading Pre-trained Models Programmatically
|
158 |
+
|
159 |
+
You can easily load pre-trained `GotenModel` instances programmatically using the `from_pretrained` class method. This method can accept a model alias (which will be resolved to a download URL), a direct HTTPS URL to a checkpoint file, or a local file path. It handles automatic downloading and caching of checkpoints. Pre-trained model weights and aliases are hosted on the [GotenNet Hugging Face Model Hub](https://huggingface.co/sarpaykent/GotenNet).
|
160 |
+
|
161 |
+
```python
|
162 |
+
from gotennet.models import GotenModel
|
163 |
+
|
164 |
+
# Example 1: Load by model alias
|
165 |
+
# This will automatically download from a known location if not found locally.
|
166 |
+
# The format is {dataset}_{size}_{target}
|
167 |
+
model_by_alias = GotenModel.from_pretrained("QM9_small_homo")
|
168 |
+
|
169 |
+
# Example 2: Load from a direct URL
|
170 |
+
model_url = "https://huggingface.co/sarpaykent/GotenNet/resolve/main/pretrained/qm9/small/gotennet_homo.ckpt" # Replace with an actual URL
|
171 |
+
model_by_url = GotenModel.from_pretrained(model_url)
|
172 |
+
|
173 |
+
# Example 3: Load from a local file path
|
174 |
+
local_model_path = "/path/to/your/local_model.ckpt"
|
175 |
+
model_by_path = GotenModel.from_pretrained(local_model_path)
|
176 |
+
|
177 |
+
# After loading, the model is ready for inference:
|
178 |
+
predictions = model_by_alias(data_input)
|
179 |
+
```
|
180 |
+
|
181 |
+
For more advanced scenarios, if you only need to load the base `GotenNet` representation module from a local checkpoint (e.g., a checkpoint that only contains representation weights), you can use:
|
182 |
+
|
183 |
+
```python
|
184 |
+
from gotennet.models.representation import GotenNet, GotenNetWrapper
|
185 |
+
|
186 |
+
# Example: Load a GotenNet representation from a local file
|
187 |
+
representation_checkpoint_path = "/path/to/your/local_model.ckpt"
|
188 |
+
gotennet_model = GotenNet.load_from_checkpoint(representation_checkpoint_path)
|
189 |
+
# or
|
190 |
+
gotennet_wrapped = GotenNetWrapper.load_from_checkpoint(representation_checkpoint_path)
|
191 |
+
```
|
192 |
+
|
193 |
+
### Training a Model
|
194 |
+
|
195 |
+
After installation, you can use the `train_gotennet` command:
|
196 |
+
|
197 |
+
```bash
|
198 |
+
train_gotennet
|
199 |
+
```
|
200 |
+
|
201 |
+
Or you can run the training script directly:
|
202 |
+
|
203 |
+
```bash
|
204 |
+
python gotennet/scripts/train.py
|
205 |
+
```
|
206 |
+
|
207 |
+
Both methods use Hydra for configuration. You can reproduce U0 target prediction on the QM9 dataset with the following command:
|
208 |
+
|
209 |
+
```bash
|
210 |
+
train_gotennet experiment=qm9_u0.yaml
|
211 |
+
```
|
212 |
+
|
213 |
+
### Testing a Model
|
214 |
+
|
215 |
+
To evaluate a trained model, you can use the `test_gotennet` script. When you provide a checkpoint, the script can infer necessary configurations (like dataset and task details) directly from the checkpoint file. This script leverages the `GotenModel.from_pretrained` capabilities, allowing you to specify the model to test by its alias, a direct URL, or a local file path, handling automatic downloads.
|
216 |
+
|
217 |
+
Here's how you can use it:
|
218 |
+
|
219 |
+
```bash
|
220 |
+
# Option 1: Test by model alias (e.g., QM9_small_homo)
|
221 |
+
# The script will automatically download the checkpoint and infer configurations.
|
222 |
+
test_gotennet checkpoint=QM9_small_homo
|
223 |
+
|
224 |
+
# Option 2: Test with a direct checkpoint URL
|
225 |
+
# The script will automatically download the checkpoint and infer configurations.
|
226 |
+
test_gotennet checkpoint=https://huggingface.co/sarpaykent/GotenNet/resolve/main/pretrained/qm9/small/gotennet_homo.ckpt
|
227 |
+
|
228 |
+
# Option 3: Test with a local checkpoint file path
|
229 |
+
test_gotennet checkpoint=/path/to/your/local_model.ckpt
|
230 |
+
```
|
231 |
+
|
232 |
+
The script uses [Hydra](https://hydra.cc/) for any additional or overriding configurations if needed, but for straightforward evaluation of a checkpoint, only the `checkpoint` argument is typically required.
|
233 |
+
|
234 |
+
### Configuration
|
235 |
+
|
236 |
+
The project uses [Hydra](https://hydra.cc/) for configuration management. Configuration files are located in the `configs/` directory.
|
237 |
+
|
238 |
+
Main configuration categories:
|
239 |
+
- `datamodule`: Dataset configurations (md17, qm9, etc.)
|
240 |
+
- `model`: Model configurations
|
241 |
+
- `trainer`: Training parameters
|
242 |
+
- `callbacks`: Callback configurations
|
243 |
+
- `logger`: Logging configurations
|
244 |
+
|
245 |
+
## 🤝 Contributing
|
246 |
+
|
247 |
+
We welcome contributions to GotenNet! Please feel free to submit a Pull Request.
|
248 |
+
|
249 |
+
|
250 |
+
## 📚 Citation
|
251 |
+
|
252 |
+
Please consider citing our work below if this project is helpful:
|
253 |
+
|
254 |
+
|
255 |
+
```bibtex
|
256 |
+
@inproceedings{aykent2025gotennet,
|
257 |
+
author = {Aykent, Sarp and Xia, Tian},
|
258 |
+
booktitle = {The Thirteenth International Conference on LearningRepresentations},
|
259 |
+
year = {2025},
|
260 |
+
title = {{GotenNet: Rethinking Efficient 3D Equivariant Graph Neural Networks}},
|
261 |
+
url = {https://openreview.net/forum?id=5wxCQDtbMo},
|
262 |
+
howpublished = {https://openreview.net/forum?id=5wxCQDtbMo},
|
263 |
+
}
|
264 |
+
```
|
265 |
+
|
266 |
+
## 📄 License
|
267 |
+
|
268 |
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
269 |
+
|
270 |
+
## Acknowledgements
|
271 |
+
|
272 |
+
GotenNet is proudly built on the innovative foundations provided by the projects below.
|
273 |
+
- [e3nn](https://github.com/e3nn/e3nn)
|
274 |
+
- [PyG](https://github.com/pyg-team/pytorch_geometric)
|
275 |
+
- [PyTorch Lightning](https://github.com/Lightning-AI/pytorch-lightning)
|