kmchiti committed · Commit 8097b37 · verified · 1 Parent(s): 37077c6

Initial upload with README & teaser

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/NovoMolGen.png filter=lfs diff=lfs merge=lfs -text
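The new rule keeps the teaser image in Git LFS. A minimal sketch of fetching it, assuming the asset ships in the same `chandar-lab/NovoMolGen_300M_SMILES_BPE` repo referenced in the README (LFS resolution is handled transparently by the Hub):

```python
# Sketch: download the LFS-tracked teaser image from the Hub.
# Assumes the asset lives in the chandar-lab/NovoMolGen_300M_SMILES_BPE repo.
from huggingface_hub import hf_hub_download

teaser_path = hf_hub_download(
    repo_id="chandar-lab/NovoMolGen_300M_SMILES_BPE",
    filename="assets/NovoMolGen.png",
)
print(teaser_path)  # local path to the cached PNG
```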
README.md CHANGED
@@ -14,7 +14,12 @@ pipeline_tag: text-generation
 
 # NovoMolGen
 
-NovoMolGen is a family of molecular foundation models trained on 1.5 billion ZINC‑22 molecules using Llama architectures and FlashAttention. It achieves state‑of‑the‑art performance on both unconstrained and goal‑directed molecule generation tasks.
+NovoMolGen is a family of molecular foundation models trained on
+1.5 billion ZINC-22 molecules with Llama architectures and FlashAttention.
+It achieves state-of-the-art performance on both unconstrained and
+goal-directed molecule generation tasks.
+
+<img src="assets/NovoMolGen.png" width="900"/>
 
 ## How to load
 
@@ -24,9 +29,14 @@ tokenizer = AutoTokenizer.from_pretrained("chandar-lab/NovoMolGen_300M_SMILES_BP
 model = AutoModelForCausalLM.from_pretrained("chandar-lab/NovoMolGen_300M_SMILES_BPE", trust_remote_code=True)
 ```
 
-## Quickstart
+## Quick-start (FlashAttention + bf16)
 
 ```python
+from accelerate import Accelerator
+
+acc = Accelerator(mixed_precision='bf16')
+model = acc.prepare(model)
+
 outputs = model.sample(tokenizer=tokenizer, batch_size=4)
 print(outputs['SMILES'])
 ```
@@ -36,7 +46,8 @@ print(outputs['SMILES'])
 ```bibtex
 @article{chitsaz2024novomolgen,
   title={NovoMolGen: Rethinking Molecular Language Model Pretraining},
-  author={Chitsaz, Kamran and Balaji, Roshan and Fournier, Quentin and Bhatt, Nirav Pravinbhai and Chandar, Sarath},
+  author={Chitsaz, Kamran and Balaji, Roshan and Fournier, Quentin and
+          Bhatt, Nirav Pravinbhai and Chandar, Sarath},
   journal={arXiv preprint},
   year={2025},
 }
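Taken together, the updated README describes a load-then-sample workflow. A minimal end-to-end sketch, assuming a GPU with bf16 support (as the FlashAttention + bf16 quick-start implies) and that the tokenizer is loaded with the same `trust_remote_code=True` flag as the model:

```python
# End-to-end sketch combining the README snippets above.
# Assumes a GPU with bf16 support; Accelerator handles device placement and casting.
from transformers import AutoTokenizer, AutoModelForCausalLM
from accelerate import Accelerator

repo = "chandar-lab/NovoMolGen_300M_SMILES_BPE"
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

acc = Accelerator(mixed_precision="bf16")
model = acc.prepare(model)

# `sample` is the custom generation entry point exposed by the repo's remote code.
outputs = model.sample(tokenizer=tokenizer, batch_size=4)
print(outputs["SMILES"])
```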
assets/NovoMolGen.png ADDED

Git LFS Details

  • SHA256: 9db01aa9afa3b39dfe789009590118265a2dddb5ebf057e878ecfc9dd9328ce8
  • Pointer size: 132 Bytes
  • Size of remote file: 5.08 MB
modeling_novomolgen.py CHANGED
@@ -33,7 +33,7 @@ except ImportError:
     inv_remap_state_dict_hf_llama = None
 
 
-def state_dict_from_pretrained(model_name, checkpoint_path: str = "", device=None, dtype=None):
+def state_dict_from_pretrained(model_name, checkpoint_path: str = "", device=None, dtype=None, **kwargs):
     """
     code modified from: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/pretrained.py
     """
@@ -45,10 +45,10 @@ def state_dict_from_pretrained(model_name, checkpoint_path: str = "", device=Non
 
     # Try loading from HF hub instead of from local files
     resolved_archive_file = cached_file(model_name, os.path.join(checkpoint_path, WEIGHTS_NAME),
-                                        _raise_exceptions_for_missing_entries=False)
+                                        _raise_exceptions_for_missing_entries=False, **kwargs)
     if resolved_archive_file is None:
         resolved_archive_file = cached_file(model_name, os.path.join(checkpoint_path, WEIGHTS_INDEX_NAME),
-                                            _raise_exceptions_for_missing_entries=False)
+                                            _raise_exceptions_for_missing_entries=False, **kwargs)
     if resolved_archive_file is not None:
         is_sharded = True
 
@@ -115,7 +115,7 @@ class NovoMolGenConfig(LlamaConfig):
 
         resolved_archive_config_file = cached_file(pretrained_model_name_or_path,
                                                    os.path.join(checkpoint_path, "config.json"),
-                                                   _raise_exceptions_for_missing_entries=False)
+                                                   _raise_exceptions_for_missing_entries=False, force_download=force_download)
 
         if resolved_archive_config_file is not None:
             with open(resolved_archive_config_file, "r", encoding="utf-8") as reader:
@@ -266,13 +266,13 @@ class NovoMolGen(GPTLMHeadModel):
         **kwargs,
     ):
         if config is None:
-            config = NovoMolGenConfig.from_pretrained(pretrained_model_name_or_path, checkpoint_path=checkpoint_path)
+            config = NovoMolGenConfig.from_pretrained(pretrained_model_name_or_path, checkpoint_path=checkpoint_path, **kwargs)
         model = cls(config)
 
         if os.path.exists(pretrained_model_name_or_path):
             state_dict = torch.load(os.path.join(pretrained_model_name_or_path, checkpoint_path, WEIGHTS_NAME))
         else:
-            state_dict = state_dict_from_pretrained(pretrained_model_name_or_path, checkpoint_path=checkpoint_path)
+            state_dict = state_dict_from_pretrained(pretrained_model_name_or_path, checkpoint_path=checkpoint_path, **kwargs)
         model.load_state_dict(state_dict)
         return model
 
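The thread running through these hunks is that keyword arguments now flow from `NovoMolGen.from_pretrained` into `NovoMolGenConfig.from_pretrained`, `state_dict_from_pretrained`, and ultimately `cached_file`. A hypothetical sketch of what that enables for a caller using the module directly; `force_download` is the only option named in the diff, and `revision` is assumed here as a standard `cached_file` keyword:

```python
# Hypothetical sketch: Hub download options forwarded via the new **kwargs.
# Requires modeling_novomolgen.py to be importable locally (e.g. a cloned repo).
from modeling_novomolgen import NovoMolGen

model = NovoMolGen.from_pretrained(
    "chandar-lab/NovoMolGen_300M_SMILES_BPE",
    force_download=True,  # named in the diff (config path)
    revision="main",      # assumed standard Hub kwarg, passed through **kwargs
)
```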
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:3bbb07cc62cec4185d2b79ef6e60d1dec5627ad138283fa7063d4c2371959fd9
-size 1211329558
+oid sha256:b2c3c25c79d8b9317f7bebe4ff9029ec9524ea66cb426569cd2a7fd8ca8c2e37
+size 1211326870
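The checkpoint change is an LFS pointer swap: a new content hash and a slightly smaller file. A minimal sketch for verifying a downloaded copy against the new pointer (the local path is a placeholder):

```python
# Sketch: check a downloaded pytorch_model.bin against the new LFS pointer.
import hashlib
import os

EXPECTED_SHA256 = "b2c3c25c79d8b9317f7bebe4ff9029ec9524ea66cb426569cd2a7fd8ca8c2e37"
EXPECTED_SIZE = 1211326870  # bytes

path = "pytorch_model.bin"  # placeholder; point at the downloaded file
sha = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha.update(chunk)

print(os.path.getsize(path) == EXPECTED_SIZE, sha.hexdigest() == EXPECTED_SHA256)
```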