sarpaykent committed on
Commit 0281c63 · verified · 1 Parent(s): f5ae9ee

Update README.md

Files changed (1)
  1. README.md +275 -3
README.md CHANGED
@@ -1,3 +1,275 @@
1
- ---
2
- license: mit
3
- ---
1
+ ---
2
+ license: mit
3
+ datasets:
4
+ - colabfit/MD22_buckyball_catcher
5
+ - colabfit/MD22_AT_AT
6
+ - colabfit/MD22_stachyose
7
+ - colabfit/MD22_AT_AT_CG_CG
8
+ - colabfit/MD22_Ac_Ala3_NHMe
9
+ - colabfit/MD22_DHA
10
+ - colabfit/MD22_double_walled_nanotube
11
+ - yairschiff/qm9
12
+ - maomlab/Molecule3D
13
+ metrics:
14
+ - mae
15
+ tags:
16
+ - equivariant
17
+ - graph neural network
18
+ - molecular property prediction
19
+ ---
20
+
21
+ # GotenNet: Rethinking Efficient 3D Equivariant Graph Neural Networks
22
+
23
+ <div align="center">
24
+
25
+ [![Paper](https://img.shields.io/badge/Paper-ICLR%202025-blue)](https://openreview.net/pdf?id=5wxCQDtbMo)
26
+ [![Project Page](https://img.shields.io/badge/Project-Website-green)](https://www.sarpaykent.com/publications/gotennet/)
27
+ [![License](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
28
+ [![PyPI - Version](https://img.shields.io/pypi/v/gotennet)](https://pypi.org/project/gotennet/)
29
+ [![PyTorch](https://img.shields.io/badge/PyTorch-2.0+-red.svg)](https://pytorch.org/)
30
+
31
+ </div>
32
+
33
+ <p align="center">
34
+ <img src="https://raw.githubusercontent.com/sarpaykent/GotenNet/refs/heads/main/assets/GotenNet_framework.png" width="800">
35
+ </p>
36
+
37
+ ## Overview
38
+
39
+ This is the official implementation of **"GotenNet: Rethinking Efficient 3D Equivariant Graph Neural Networks"** published at ICLR 2025.
40
+
41
+ GotenNet is a framework for modeling 3D molecular structures that achieves state-of-the-art performance while remaining computationally efficient. It balances expressiveness and efficiency by combining geometric tensor representations, used without irreducible representations or Clebsch-Gordan transforms, with geometry-aware tensor attention.
42
+
43
+ ## Table of Contents
44
+ - [✨ Key Features](#-key-features)
45
+ - [🚀 Installation](#-installation)
46
+ - [📦 From PyPI (Recommended)](#-from-pypi-recommended)
47
+ - [🔧 From Source](#-from-source)
48
+ - [🔬 Usage](#-usage)
49
+ - [Using the Model](#using-the-model)
50
+ - [Loading Pre-trained Models Programmatically](#loading-pre-trained-models-programmatically)
51
+ - [Training a Model](#training-a-model)
52
+ - [Testing a Model](#testing-a-model)
53
+ - [Configuration](#configuration)
54
+ - [🤝 Contributing](#-contributing)
55
+ - [📚 Citation](#-citation)
56
+ - [📄 License](#-license)
57
+ - [Acknowledgements](#acknowledgements)
58
+
59
+ ## ✨ Key Features
60
+
61
+ - 🔄 **Effective Geometric Tensor Representations**: Leverages geometric tensors without relying on irreducible representations or Clebsch-Gordan transforms
62
+ - 🧩 **Unified Structural Embedding**: Introduces geometry-aware tensor attention for improved molecular representation
63
+ - 📊 **Hierarchical Tensor Refinement**: Implements a flexible and efficient representation scheme
64
+ - 🏆 **State-of-the-Art Performance**: Achieves superior results on QM9, rMD17, MD22, and Molecule3D datasets
65
+ - 📈 **Load Pre-trained Models**: Load pre-trained model checkpoints by alias, URL, or local path, with automatic download and caching
66
+
67
+ ## 🚀 Installation
68
+
69
+ ### 📦 From PyPI (Recommended)
70
+
71
+ You can install GotenNet using pip:
72
+
73
+ * **Core Model Only:** Installs only the essential dependencies required to use the `GotenNet` model.
74
+ ```bash
75
+ pip install gotennet
76
+ ```
77
+
78
+ * **Full Installation (Core + Training/Utilities):** Installs core dependencies plus libraries needed for training, data handling, logging, etc.
79
+ ```bash
80
+ pip install "gotennet[full]"
81
+ ```
82
+
83
+ ### 🔧 From Source
84
+
85
+ 1. **Clone the repository:**
86
+ ```bash
87
+ git clone https://github.com/sarpaykent/gotennet.git
88
+ cd gotennet
89
+ ```
90
+
91
+ 2. **Create and activate a virtual environment** (using conda or venv/uv):
92
+ ```bash
93
+ # Using conda
94
+ conda create -n gotennet python=3.10
95
+ conda activate gotennet
96
+
97
+ # Or using venv/uv
98
+ uv venv --python 3.10
99
+ source .venv/bin/activate
100
+ ```
101
+
102
+ 3. **Install the package:**
103
+ Choose the installation type based on your needs:
104
+
105
+ * **Core Model Only:** Installs only the essential dependencies required to use the `GotenNet` model.
106
+ ```bash
107
+ pip install .
108
+ ```
109
+
110
+ * **Full Installation (Core + Training/Utilities):** Installs core dependencies plus libraries needed for training, data handling, logging, etc.
111
+ ```bash
112
+ pip install ".[full]"
113
+ # Or for editable install:
114
+ # pip install -e ".[full]"
115
+ ```
116
+ *(Note: `uv` can be used as a faster alternative to `pip` for installation, e.g., `uv pip install ".[full]"`)*
117
+
118
+ ## 🔬 Usage
119
+
120
+ ### Using the Model
121
+
122
+ Once installed, you can import and use the `GotenNet` model directly in your Python code:
123
+
124
+ ```python
125
+ from gotennet import GotenNet
126
+
127
+ # --- Using the base GotenNet model ---
128
+ # Requires manual calculation of edge_index, edge_diff, edge_vec
129
+
130
+ # Example instantiation
131
+ model = GotenNet(
132
+ n_atom_basis=256,
133
+ n_interactions=4,
134
+ # rest of the parameters
135
+ )
136
+
137
+ # Encoded representations can be computed with
138
+ h, X = model(atomic_numbers, edge_index, edge_diff, edge_vec)
139
+
140
+ # --- Using GotenNetWrapper (handles distance calculation) ---
141
+ # Expects a PyTorch Geometric Data object or similar dict
142
+ # with keys like 'z' (atomic_numbers), 'pos' (positions), 'batch'
143
+
144
+ # Example instantiation
145
+ from gotennet import GotenNetWrapper
146
+ wrapped_model = GotenNetWrapper(
147
+ n_atom_basis=256,
148
+ n_interactions=4,
149
+ # rest of the parameters
150
+ )
151
+
152
+ # Encoded representations can be computed with
153
+ h, X = wrapped_model(data)
154
+
155
+ ```
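
For readers wiring up the base `GotenNet` model by hand, the three edge inputs can be pictured with a small pure-Python sketch. It is an illustration only: the helper name `radius_graph`, the cutoff value, and the conventions used here (`edge_diff` as scalar distances, `edge_vec` as unnormalized displacement vectors from source to target atom) are assumptions that this README does not confirm. `GotenNetWrapper` remains the documented way to have these quantities computed for you.

```python
import math


def radius_graph(positions, cutoff):
    """Collect all ordered atom pairs (i, j), i != j, within `cutoff`.

    Hypothetical helper. Assumed conventions (not confirmed by the README):
      edge_index -> [source indices, target indices]
      edge_diff  -> scalar distance for each edge
      edge_vec   -> displacement vector pos[j] - pos[i] for each edge
    """
    src, dst, dists, vecs = [], [], [], []
    for i, pos_i in enumerate(positions):
        for j, pos_j in enumerate(positions):
            if i == j:
                continue  # no self-loops
            vec = [b - a for a, b in zip(pos_i, pos_j)]
            dist = math.sqrt(sum(c * c for c in vec))
            if dist <= cutoff:
                src.append(i)
                dst.append(j)
                dists.append(dist)
                vecs.append(vec)
    return [src, dst], dists, vecs


# Toy three-atom geometry (coordinates are made up for illustration).
positions = [[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]]
edge_index, edge_diff, edge_vec = radius_graph(positions, cutoff=1.2)
print(edge_index)
```

With this geometry the two outer atoms are farther apart than the cutoff, so only the four directed edges touching atom 0 are kept. Before being passed to the model, these lists would still need converting to tensors in whatever shape, dtype, and normalization the model expects, which this README does not specify.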
156
+
157
+ ### Loading Pre-trained Models Programmatically
158
+
159
+ You can easily load pre-trained `GotenModel` instances programmatically using the `from_pretrained` class method. This method can accept a model alias (which will be resolved to a download URL), a direct HTTPS URL to a checkpoint file, or a local file path. It handles automatic downloading and caching of checkpoints. Pre-trained model weights and aliases are hosted on the [GotenNet Hugging Face Model Hub](https://huggingface.co/sarpaykent/GotenNet).
160
+
161
+ ```python
162
+ from gotennet.models import GotenModel
163
+
164
+ # Example 1: Load by model alias
165
+ # This will automatically download from a known location if not found locally.
166
+ # The format is {dataset}_{size}_{target}
167
+ model_by_alias = GotenModel.from_pretrained("QM9_small_homo")
168
+
169
+ # Example 2: Load from a direct URL
170
+ model_url = "https://huggingface.co/sarpaykent/GotenNet/resolve/main/pretrained/qm9/small/gotennet_homo.ckpt" # Replace with an actual URL
171
+ model_by_url = GotenModel.from_pretrained(model_url)
172
+
173
+ # Example 3: Load from a local file path
174
+ local_model_path = "/path/to/your/local_model.ckpt"
175
+ model_by_path = GotenModel.from_pretrained(local_model_path)
176
+
177
+ # After loading, the model is ready for inference:
178
+ predictions = model_by_alias(data_input)
179
+ ```
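
The three kinds of specifier accepted by `from_pretrained` can be told apart with a few lines of standard-library Python. The sketch below is not the library's implementation: `classify_checkpoint_spec` and `parse_alias` are hypothetical helpers, the rule that anything ending in `.ckpt` is a local path is an assumption made for illustration, and `parse_alias` assumes the dataset and size parts contain no underscores.

```python
from pathlib import Path
from urllib.parse import urlparse


def classify_checkpoint_spec(spec: str) -> str:
    """Return 'url', 'path', or 'alias' for a checkpoint specifier.

    Hypothetical helper; the real dispatch inside `from_pretrained` may differ.
    """
    if urlparse(spec).scheme in ("http", "https"):
        return "url"
    if spec.endswith(".ckpt") or Path(spec).exists():
        return "path"
    return "alias"


def parse_alias(alias: str) -> dict:
    """Split a `{dataset}_{size}_{target}` alias into its three parts.

    Splitting at most twice keeps any underscores inside the target name.
    """
    dataset, size, target = alias.split("_", 2)
    return {"dataset": dataset, "size": size, "target": target}


for spec in (
    "QM9_small_homo",
    "https://huggingface.co/sarpaykent/GotenNet/resolve/main/pretrained/qm9/small/gotennet_homo.ckpt",
    "/path/to/your/local_model.ckpt",
):
    print(classify_checkpoint_spec(spec), "<-", spec)

print(parse_alias("QM9_small_homo"))
```

Checking for a URL first matters because a URL to a checkpoint also ends in `.ckpt`; an alias is simply whatever is neither a URL nor a path.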
180
+
181
+ For more advanced scenarios, if you only need to load the base `GotenNet` representation module from a local checkpoint (e.g., a checkpoint that only contains representation weights), you can use:
182
+
183
+ ```python
184
+ from gotennet.models.representation import GotenNet, GotenNetWrapper
185
+
186
+ # Example: Load a GotenNet representation from a local file
187
+ representation_checkpoint_path = "/path/to/your/local_model.ckpt"
188
+ gotennet_model = GotenNet.load_from_checkpoint(representation_checkpoint_path)
189
+ # or
190
+ gotennet_wrapped = GotenNetWrapper.load_from_checkpoint(representation_checkpoint_path)
191
+ ```
192
+
193
+ ### Training a Model
194
+
195
+ After installation, you can use the `train_gotennet` command:
196
+
197
+ ```bash
198
+ train_gotennet
199
+ ```
200
+
201
+ Or you can run the training script directly:
202
+
203
+ ```bash
204
+ python gotennet/scripts/train.py
205
+ ```
206
+
207
+ Both methods use Hydra for configuration. You can reproduce U0 target prediction on the QM9 dataset with the following command:
208
+
209
+ ```bash
210
+ train_gotennet experiment=qm9_u0.yaml
211
+ ```
212
+
213
+ ### Testing a Model
214
+
215
+ To evaluate a trained model, you can use the `test_gotennet` script. When you provide a checkpoint, the script can infer necessary configurations (like dataset and task details) directly from the checkpoint file. This script leverages the `GotenModel.from_pretrained` capabilities, allowing you to specify the model to test by its alias, a direct URL, or a local file path, handling automatic downloads.
216
+
217
+ Here's how you can use it:
218
+
219
+ ```bash
220
+ # Option 1: Test by model alias (e.g., QM9_small_homo)
221
+ # The script will automatically download the checkpoint and infer configurations.
222
+ test_gotennet checkpoint=QM9_small_homo
223
+
224
+ # Option 2: Test with a direct checkpoint URL
225
+ # The script will automatically download the checkpoint and infer configurations.
226
+ test_gotennet checkpoint=https://huggingface.co/sarpaykent/GotenNet/resolve/main/pretrained/qm9/small/gotennet_homo.ckpt
227
+
228
+ # Option 3: Test with a local checkpoint file path
229
+ test_gotennet checkpoint=/path/to/your/local_model.ckpt
230
+ ```
231
+
232
+ The script uses [Hydra](https://hydra.cc/) for any additional or overriding configurations if needed, but for straightforward evaluation of a checkpoint, only the `checkpoint` argument is typically required.
233
+
234
+ ### Configuration
235
+
236
+ The project uses [Hydra](https://hydra.cc/) for configuration management. Configuration files are located in the `configs/` directory.
237
+
238
+ Main configuration categories:
239
+ - `datamodule`: Dataset configurations (md17, qm9, etc.)
240
+ - `model`: Model configurations
241
+ - `trainer`: Training parameters
242
+ - `callbacks`: Callback configurations
243
+ - `logger`: Logging configurations
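
Hydra overrides are plain `key=value` arguments, where dotted keys reach into a config group. The sketch below assembles such a command line from Python, which can be handy when launching runs from a notebook or a sweep script. `hydra_command` is a hypothetical helper, and `trainer.max_epochs` is an assumed key: it is a standard PyTorch Lightning `Trainer` argument, but check the `trainer` configs under `configs/` for the names this project actually exposes.

```python
import shlex


def hydra_command(entry_point, overrides):
    """Build an argv list of Hydra-style `key=value` overrides.

    Hypothetical helper for illustration; it only formats arguments
    and does not validate them against the project's configs.
    """
    return [entry_point] + [f"{key}={value}" for key, value in overrides.items()]


argv = hydra_command(
    "train_gotennet",
    {
        "experiment": "qm9_u0.yaml",   # experiment file named in this README
        "trainer.max_epochs": 10,      # assumed key, see the note above
    },
)
command_line = shlex.join(argv)
print(command_line)
```

The resulting list can be passed to `subprocess.run(argv)` once the package is installed; building a list rather than a single string avoids shell-quoting problems.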
244
+
245
+ ## 🤝 Contributing
246
+
247
+ We welcome contributions to GotenNet! Please feel free to submit a Pull Request.
248
+
249
+
250
+ ## 📚 Citation
251
+
252
+ Please consider citing our work below if this project is helpful:
253
+
254
+
255
+ ```bibtex
256
+ @inproceedings{aykent2025gotennet,
257
+ author = {Aykent, Sarp and Xia, Tian},
258
+ booktitle = {The Thirteenth International Conference on Learning Representations},
259
+ year = {2025},
260
+ title = {{GotenNet: Rethinking Efficient 3D Equivariant Graph Neural Networks}},
261
+ url = {https://openreview.net/forum?id=5wxCQDtbMo},
262
+ howpublished = {https://openreview.net/forum?id=5wxCQDtbMo},
263
+ }
264
+ ```
265
+
266
+ ## 📄 License
267
+
268
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
269
+
270
+ ## Acknowledgements
271
+
272
+ GotenNet is proudly built on the innovative foundations provided by the projects below.
273
+ - [e3nn](https://github.com/e3nn/e3nn)
274
+ - [PyG](https://github.com/pyg-team/pytorch_geometric)
275
+ - [PyTorch Lightning](https://github.com/Lightning-AI/pytorch-lightning)