# NovoMolGen_300M_SMILES_AtomWise / modeling_novomolgen.py
import json
import os.path
import re
import shutil
import inspect
from typing import Optional, Union
import torch
import torch.nn.functional as F
from transformers import LlamaConfig
from transformers.loss.loss_utils import LOSS_MAPPING
from transformers.modeling_outputs import CausalLMOutput
from transformers.utils.hub import cached_file, get_checkpoint_shard_files
from transformers.utils import (
SAFE_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
)
from transformers.modeling_utils import unwrap_model, logger
from functools import partial
from safetensors.torch import load_file as safe_load_file
try:
from flash_attn.models.gpt import GPTLMHeadModel
except ImportError:
GPTLMHeadModel = None
try:
from flash_attn.models.llama import llama_config_to_gpt2_config, inv_remap_state_dict_hf_llama
except ImportError:
llama_config_to_gpt2_config = None
inv_remap_state_dict_hf_llama = None
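# NOTE: the fallbacks above only soften the import of these names; `NovoMolGen`
# below subclasses `GPTLMHeadModel`, so importing this module without a working
# flash-attn installation will still fail at the class definition.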
def state_dict_from_pretrained(model_name, checkpoint_path: str = "", device=None, dtype=None, **kwargs):
"""
code modified from: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/pretrained.py
"""
# If not fp32, then we don't want to load directly to the GPU
mapped_device = "cpu" if dtype not in [torch.float32, None] else device
    is_sharded = False
    load_safe = False
    # Try loading from the HF hub instead of from local files: safetensors first,
    # then the single-file PyTorch weights, then a sharded index.
    resolved_archive_file = cached_file(model_name, os.path.join(checkpoint_path, SAFE_WEIGHTS_NAME),
                                        _raise_exceptions_for_missing_entries=False, **kwargs)
    if resolved_archive_file is not None:
        load_safe = True
    else:
        resolved_archive_file = cached_file(model_name, os.path.join(checkpoint_path, WEIGHTS_NAME),
                                            _raise_exceptions_for_missing_entries=False, **kwargs)
    if resolved_archive_file is None:
        resolved_archive_file = cached_file(model_name, os.path.join(checkpoint_path, WEIGHTS_INDEX_NAME),
                                            _raise_exceptions_for_missing_entries=False, **kwargs)
        if resolved_archive_file is not None:
            is_sharded = True
    if resolved_archive_file is None:
        raise EnvironmentError(f"No weights found for model {model_name}.")
if load_safe:
loader = partial(safe_load_file, device=mapped_device)
else:
loader = partial(torch.load, map_location=mapped_device)
if is_sharded:
# resolved_archive_file becomes a list of files that point to the different
# checkpoint shards in this case.
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
model_name, resolved_archive_file
)
state_dict = {}
for sharded_file in resolved_archive_file:
state_dict.update(loader(sharded_file))
else:
state_dict = loader(resolved_archive_file)
# Convert dtype before moving to GPU to save memory
if dtype is not None:
state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
return state_dict
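# Usage sketch for the helper above (the repo id is illustrative, not a
# confirmed checkpoint name):
#   sd = state_dict_from_pretrained("org/NovoMolGen_300M_SMILES_AtomWise", dtype=torch.float16)
#   model.load_state_dict(sd)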
class NovoMolGenConfig(LlamaConfig):
# model_type = "NovoMolGen"
def __init__(self,
use_flash_attn: bool = True,
fused_bias_fc: bool = True,
fused_mlp: bool = False,
fused_dropout_add_ln: bool = True,
residual_in_fp32: bool = True,
loss_type: str = 'ForCausalLM',
**kwargs
):
super().__init__(**kwargs)
self.use_flash_attn = use_flash_attn
self.fused_bias_fc = fused_bias_fc
self.fused_mlp = fused_mlp
self.fused_dropout_add_ln = fused_dropout_add_ln
self.residual_in_fp32 = residual_in_fp32
self.loss_type = loss_type
self.auto_map = {"AutoModelForCausalLM": "modeling_novomolgen.NovoMolGen"}
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
checkpoint_path: str = "",
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
**kwargs,
):
        resolved_archive_config_file = cached_file(pretrained_model_name_or_path,
                                                   os.path.join(checkpoint_path, "config.json"),
                                                   _raise_exceptions_for_missing_entries=False,
                                                   cache_dir=cache_dir, force_download=force_download,
                                                   local_files_only=local_files_only, token=token, revision=revision)
if resolved_archive_config_file is not None:
with open(resolved_archive_config_file, "r", encoding="utf-8") as reader:
text = reader.read()
config_dict = json.loads(text)
else:
raise EnvironmentError(f"config for {pretrained_model_name_or_path} was not found.")
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )
return cls.from_dict(config_dict, **kwargs)
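# The `auto_map` entry set in `__init__` is what lets
# `AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True)` resolve
# `NovoMolGen` from this file when it ships alongside the checkpoint
# (see `save_pretrained`, which copies this source file next to the weights).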
class NovoMolGen(GPTLMHeadModel):
def __init__(
self,
config: NovoMolGenConfig,
mol_type: str = "SMILES",
):
self.base_config = config
self.mol_type = mol_type
config = llama_config_to_gpt2_config(config)
config.use_flash_attn = self.base_config.use_flash_attn
config.fused_bias_fc = self.base_config.fused_bias_fc
config.fused_mlp = self.base_config.fused_mlp
config.fused_dropout_add_ln = self.base_config.fused_dropout_add_ln
config.residual_in_fp32 = self.base_config.residual_in_fp32
GPTLMHeadModel.__init__(self, config)
    # TODO: here we ignore `attention_mask` to stay compatible with the HF trainer. The MHA in
    # flash-attention should be reimplemented to integrate the attention mask, as is done here:
    # https://github.com/huggingface/transformers/blob/0864dd3beb238b7bec3528a3d1d6c17a28f51a51/src/transformers/models/llama/modeling_llama.py#L536
def forward(self, input_ids, attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None,
position_ids=None, inference_params=None, num_last_tokens=0, **loss_kwargs):
"""
input_ids: (batch, seqlen) int tensor
inference_params: for generation. Adapted from Megatron-LM (and Apex)
https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
num_last_tokens: if > 0, only return the logits for the last n tokens
"""
assert (
input_ids.ndim == 2
), f"Expected `input_ids` to have shape [b, slen], but got shape {input_ids.shape}"
b, slen = input_ids.shape
hidden_states = self.transformer(
input_ids, position_ids=position_ids, inference_params=inference_params
)
if inference_params is not None:
assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
if num_last_tokens > 0:
hidden_states = hidden_states[:, -num_last_tokens:]
if self.project_out is not None:
hidden_states = self.project_out(hidden_states)
if self.output_scale != 1.0:
hidden_states = hidden_states * self.output_scale
if not self.norm_head:
lm_logits = self.lm_head(hidden_states)
else:
lm_head_weight = F.normalize(self.lm_head.weight)
# if isinstance(self.lm_head, ColumnParallelLinear) and self.lm_head.sequence_parallel:
# hidden_states = all_gather(hidden_states, self.lm_head.process_group)
lm_logits = F.linear(hidden_states, lm_head_weight, bias=self.lm_head.bias)
# During inference, we want the full logit for sampling
# if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
# lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group)
# lm_logits = rearrange(lm_logits, "(n b) ... d -> b ... (n d)", b=b)
loss = None
if labels is not None:
loss = self.loss_function(logits=lm_logits, labels=labels, vocab_size=self.base_config.vocab_size,
**loss_kwargs)
return CausalLMOutput(
loss=loss,
logits=lm_logits,
hidden_states=hidden_states
)
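    # Example call (tensors are illustrative; `labels` may simply repeat
    # `input_ids`, since `ForCausalLMLoss` shifts them internally):
    #   out = model(input_ids=batch["input_ids"], labels=batch["input_ids"])
    #   out.loss, out.logits  # CausalLMOutput fields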
@property
def loss_function(self):
if getattr(self.base_config, "loss_type", None) is not None:
loss_type = self.base_config.loss_type
else:
loss_type = self.__class__.__name__
if loss_type not in LOSS_MAPPING:
loss_groups = f"({'|'.join(LOSS_MAPPING)})"
loss_type = re.findall(loss_groups, self.__class__.__name__)
if len(loss_type) > 0:
loss_type = loss_type[0]
else:
loss_type = None
if loss_type is None or loss_type not in LOSS_MAPPING and getattr(self.base_config, "loss_type",
None) is not None:
            logger.warning(
                f"`loss_type={loss_type}` was set in the base_config but it is unrecognised. "
                f"Using the default loss: `ForCausalLMLoss`."
            )
loss_type = "ForCausalLM"
return LOSS_MAPPING[loss_type]
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
state_dict: Optional[dict] = None,
safe_serialization: bool = False,
**kwargs,
):
        if safe_serialization:
            raise NotImplementedError("`safe_serialization` is not implemented yet.")
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
os.makedirs(save_directory, exist_ok=True)
# Save the config
if is_main_process:
self.base_config.save_pretrained(save_directory)
# Save the model
if state_dict is None:
            # Unwrap the model (e.g., from a DistributedDataParallel wrapper used in
            # distributed training) so the bare weights are saved
            model_to_save = unwrap_model(self)
state_dict = model_to_save.state_dict()
weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
torch.save(state_dict, os.path.join(save_directory, weights_name))
        # Copy the source file where NovoMolGen is defined next to the weights, so
        # `auto_map` remote-code loading can find it
src = inspect.getsourcefile(type(self))
if src:
dst = os.path.join(save_directory, os.path.basename(src))
shutil.copy(src, dst)
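    # Round-trip sketch (path is illustrative):
    #   model.save_pretrained("./ckpt")   # writes config.json, pytorch_model.bin,
    #                                     # and a copy of this source file
    #   reloaded = NovoMolGen.from_pretrained("./ckpt")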
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path,
checkpoint_path: str = "",
config: Optional[Union[NovoMolGenConfig, str, os.PathLike]] = None,
**kwargs,
):
if config is None:
config = NovoMolGenConfig.from_pretrained(pretrained_model_name_or_path, checkpoint_path=checkpoint_path, **kwargs)
model = cls(config)
if os.path.exists(pretrained_model_name_or_path):
            state_dict = torch.load(os.path.join(pretrained_model_name_or_path, checkpoint_path, WEIGHTS_NAME),
                                    map_location="cpu")
else:
state_dict = state_dict_from_pretrained(pretrained_model_name_or_path, checkpoint_path=checkpoint_path, **kwargs)
model.load_state_dict(state_dict)
return model
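    # Loading sketch: a local directory takes the `os.path.exists` fast path above;
    # anything else is treated as a hub repo id (illustrative name shown):
    #   model = NovoMolGen.from_pretrained("org/NovoMolGen_300M_SMILES_AtomWise")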
def sample(
self,
tokenizer,
batch_size: int = 4,
max_length: int = 64,
temperature: float = 1.0,
top_k: int = 50,
top_p: float = 0.95,
device: torch.device = torch.device("cuda"),
):
"""
Generate a batch of sequences from the model.
Returns a dictionary with up to three keys:
{
"<mol_type>": <list of raw sequences in that moltype>,
"sequences": <torch.LongTensor of valid token IDs>
}
"""
input_ids = tokenizer.encode("", return_tensors="pt").to(device)
# Repeat the prompt for the desired batch size
input_ids = input_ids.repeat_interleave(batch_size, dim=0)
# If the tokenizer includes an EOS token for an empty prompt, we remove it.
if input_ids.shape[1] > 1:
input_ids = input_ids[:, :-1]
generation_output = self.generate(
input_ids,
max_length=max_length,
temperature=temperature,
top_k=top_k,
top_p=top_p,
eos_token_id=tokenizer.eos_token_id,
return_dict_in_generate=True,
)
sequences = self._filter_tokens_after_eos(
generation_output.sequences, eos_id=tokenizer.eos_token_id
)
decoded_strings = tokenizer.batch_decode(sequences, skip_special_tokens=True)
decoded_strings = [s.replace(" ", "") for s in decoded_strings]
result = {
self.mol_type: decoded_strings,
"sequences": sequences,
}
return result
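    # Sampling sketch (the tokenizer must match the checkpoint's vocabulary):
    #   out = model.sample(tokenizer, batch_size=8, max_length=64)
    #   out["SMILES"]     # decoded molecules; the key follows `self.mol_type`
    #   out["sequences"]  # token ids with everything after the first EOS set to EOS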
@staticmethod
def _filter_tokens_after_eos(sequences, eos_id):
        # Mask everything after the first EOS in each row with EOS
        output = sequences.clone()
        for i in range(sequences.size(0)):
            row = sequences[i]
            eos_position = (row == eos_id).nonzero()
            if eos_position.numel() > 0:
                eos_position = eos_position[0, 0].item()  # index of the first occurrence
                output[i, eos_position + 1:] = eos_id
return output
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
# HF’s GenerationMixin would normally do more, but for a basic LM this usually suffices:
return {"input_ids": input_ids, "attention_mask": attention_mask}