Initial upload with README & teaser
Browse files- .gitattributes +1 -0
- README.md +14 -3
- assets/NovoMolGen.png +3 -0
- modeling_novomolgen.py +6 -6
- pytorch_model.bin +2 -2
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/NovoMolGen.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -14,7 +14,12 @@ pipeline_tag: text-generation
|
|
14 |
|
15 |
# NovoMolGen
|
16 |
|
17 |
-
NovoMolGen is a family of molecular foundation models trained on
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
## How to load
|
20 |
|
@@ -24,9 +29,14 @@ tokenizer = AutoTokenizer.from_pretrained("chandar-lab/NovoMolGen_300M_SMILES_BP
|
|
24 |
model = AutoModelForCausalLM.from_pretrained("chandar-lab/NovoMolGen_300M_SMILES_BPE", trust_remote_code=True)
|
25 |
```
|
26 |
|
27 |
-
##
|
28 |
|
29 |
```python
|
|
|
|
|
|
|
|
|
|
|
30 |
outputs = model.sample(tokenizer=tokenizer, batch_size=4)
|
31 |
print(outputs['SMILES'])
|
32 |
```
|
@@ -36,7 +46,8 @@ print(outputs['SMILES'])
|
|
36 |
```bibtex
|
37 |
@article{chitsaz2024novomolgen,
|
38 |
title={NovoMolGen: Rethinking Molecular Language Model Pretraining},
|
39 |
-
author={Chitsaz, Kamran and Balaji, Roshan and Fournier, Quentin and
|
|
|
40 |
journal={arXiv preprint},
|
41 |
year={2025},
|
42 |
}
|
|
|
14 |
|
15 |
# NovoMolGen
|
16 |
|
17 |
+
NovoMolGen is a family of molecular foundation models trained on
|
18 |
+
1.5 billion ZINC-22 molecules with Llama architectures and FlashAttention.
|
19 |
+
It achieves state-of-the-art performance on both unconstrained and
|
20 |
+
goal-directed molecule generation tasks.
|
21 |
+
|
22 |
+
<img src="assets/NovoMolGen.png" width="900"/>
|
23 |
|
24 |
## How to load
|
25 |
|
|
|
29 |
model = AutoModelForCausalLM.from_pretrained("chandar-lab/NovoMolGen_300M_SMILES_BPE", trust_remote_code=True)
|
30 |
```
|
31 |
|
32 |
+
## Quick-start (FlashAttention + bf16)
|
33 |
|
34 |
```python
|
35 |
+
from accelerate import Accelerator
|
36 |
+
|
37 |
+
acc = Accelerator(mixed_precision='bf16')
|
38 |
+
model = acc.prepare(model)
|
39 |
+
|
40 |
outputs = model.sample(tokenizer=tokenizer, batch_size=4)
|
41 |
print(outputs['SMILES'])
|
42 |
```
|
|
|
46 |
```bibtex
|
47 |
@article{chitsaz2024novomolgen,
|
48 |
title={NovoMolGen: Rethinking Molecular Language Model Pretraining},
|
49 |
+
author={Chitsaz, Kamran and Balaji, Roshan and Fournier, Quentin and
|
50 |
+
Bhatt, Nirav Pravinbhai and Chandar, Sarath},
|
51 |
journal={arXiv preprint},
|
52 |
year={2025},
|
53 |
}
|
assets/NovoMolGen.png
ADDED
![]() |
Git LFS Details
|
modeling_novomolgen.py
CHANGED
@@ -33,7 +33,7 @@ except ImportError:
|
|
33 |
inv_remap_state_dict_hf_llama = None
|
34 |
|
35 |
|
36 |
-
def state_dict_from_pretrained(model_name, checkpoint_path: str = "", device=None, dtype=None):
|
37 |
"""
|
38 |
code modified from: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/pretrained.py
|
39 |
"""
|
@@ -45,10 +45,10 @@ def state_dict_from_pretrained(model_name, checkpoint_path: str = "", device=Non
|
|
45 |
|
46 |
# Try loading from HF hub instead of from local files
|
47 |
resolved_archive_file = cached_file(model_name, os.path.join(checkpoint_path, WEIGHTS_NAME),
|
48 |
-
_raise_exceptions_for_missing_entries=False)
|
49 |
if resolved_archive_file is None:
|
50 |
resolved_archive_file = cached_file(model_name, os.path.join(checkpoint_path, WEIGHTS_INDEX_NAME),
|
51 |
-
_raise_exceptions_for_missing_entries=False)
|
52 |
if resolved_archive_file is not None:
|
53 |
is_sharded = True
|
54 |
|
@@ -115,7 +115,7 @@ class NovoMolGenConfig(LlamaConfig):
|
|
115 |
|
116 |
resolved_archive_config_file = cached_file(pretrained_model_name_or_path,
|
117 |
os.path.join(checkpoint_path, "config.json"),
|
118 |
-
_raise_exceptions_for_missing_entries=False)
|
119 |
|
120 |
if resolved_archive_config_file is not None:
|
121 |
with open(resolved_archive_config_file, "r", encoding="utf-8") as reader:
|
@@ -266,13 +266,13 @@ class NovoMolGen(GPTLMHeadModel):
|
|
266 |
**kwargs,
|
267 |
):
|
268 |
if config is None:
|
269 |
-
config = NovoMolGenConfig.from_pretrained(pretrained_model_name_or_path, checkpoint_path=checkpoint_path)
|
270 |
model = cls(config)
|
271 |
|
272 |
if os.path.exists(pretrained_model_name_or_path):
|
273 |
state_dict = torch.load(os.path.join(pretrained_model_name_or_path, checkpoint_path, WEIGHTS_NAME))
|
274 |
else:
|
275 |
-
state_dict = state_dict_from_pretrained(pretrained_model_name_or_path, checkpoint_path=checkpoint_path)
|
276 |
model.load_state_dict(state_dict)
|
277 |
return model
|
278 |
|
|
|
33 |
inv_remap_state_dict_hf_llama = None
|
34 |
|
35 |
|
36 |
+
def state_dict_from_pretrained(model_name, checkpoint_path: str = "", device=None, dtype=None, **kwargs):
|
37 |
"""
|
38 |
code modified from: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/pretrained.py
|
39 |
"""
|
|
|
45 |
|
46 |
# Try loading from HF hub instead of from local files
|
47 |
resolved_archive_file = cached_file(model_name, os.path.join(checkpoint_path, WEIGHTS_NAME),
|
48 |
+
_raise_exceptions_for_missing_entries=False, **kwargs)
|
49 |
if resolved_archive_file is None:
|
50 |
resolved_archive_file = cached_file(model_name, os.path.join(checkpoint_path, WEIGHTS_INDEX_NAME),
|
51 |
+
_raise_exceptions_for_missing_entries=False, **kwargs)
|
52 |
if resolved_archive_file is not None:
|
53 |
is_sharded = True
|
54 |
|
|
|
115 |
|
116 |
resolved_archive_config_file = cached_file(pretrained_model_name_or_path,
|
117 |
os.path.join(checkpoint_path, "config.json"),
|
118 |
+
_raise_exceptions_for_missing_entries=False, force_download=force_download)
|
119 |
|
120 |
if resolved_archive_config_file is not None:
|
121 |
with open(resolved_archive_config_file, "r", encoding="utf-8") as reader:
|
|
|
266 |
**kwargs,
|
267 |
):
|
268 |
if config is None:
|
269 |
+
config = NovoMolGenConfig.from_pretrained(pretrained_model_name_or_path, checkpoint_path=checkpoint_path, **kwargs)
|
270 |
model = cls(config)
|
271 |
|
272 |
if os.path.exists(pretrained_model_name_or_path):
|
273 |
state_dict = torch.load(os.path.join(pretrained_model_name_or_path, checkpoint_path, WEIGHTS_NAME))
|
274 |
else:
|
275 |
+
state_dict = state_dict_from_pretrained(pretrained_model_name_or_path, checkpoint_path=checkpoint_path, **kwargs)
|
276 |
model.load_state_dict(state_dict)
|
277 |
return model
|
278 |
|
pytorch_model.bin
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b2c3c25c79d8b9317f7bebe4ff9029ec9524ea66cb426569cd2a7fd8ca8c2e37
|
3 |
+
size 1211326870
|