ModChemBERT / configuration_modchembert.py
# Copyright 2025 Emmanuel Cortes, All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal

from transformers.models.modernbert.configuration_modernbert import ModernBertConfig


class ModChemBertConfig(ModernBertConfig):
"""
Configuration class for ModChemBert models.
This configuration class extends ModernBertConfig with additional parameters specific to
chemical molecule modeling and custom pooling strategies for classification/regression tasks.
It accepts all arguments and keyword arguments from ModernBertConfig.
Args:
classifier_pooling (str, optional): Pooling strategy for sequence classification.
Available options:
- "cls": Use CLS token representation
- "mean": Attention-weighted average pooling
- "sum_mean": Sum all hidden states across layers, then mean pool over sequence (ChemLM approach)
- "sum_sum": Sum all hidden states across layers, then sum pool over sequence
- "mean_mean": Mean all hidden states across layers, then mean pool over sequence
- "mean_sum": Mean all hidden states across layers, then sum pool over sequence
- "max_cls": Element-wise max pooling over last k hidden states, then take CLS token
- "cls_mha": Multi-head attention with CLS token as query and full sequence as keys/values
- "max_seq_mha": Max pooling over last k states + multi-head attention with CLS as query
- "max_seq_mean": Max pooling over last k hidden states, then mean pooling over sequence
Defaults to "sum_mean".
classifier_pooling_num_attention_heads (int, optional): Number of attention heads for multi-head attention
pooling strategies (cls_mha, max_seq_mha). Defaults to 4.
classifier_pooling_attention_dropout (float, optional): Dropout probability for multi-head attention
pooling strategies (cls_mha, max_seq_mha). Defaults to 0.0.
classifier_pooling_last_k (int, optional): Number of last hidden layers to use for max pooling
strategies (max_cls, max_seq_mha, max_seq_mean). Defaults to 8.
*args: Variable length argument list passed to ModernBertConfig.
**kwargs: Arbitrary keyword arguments passed to ModernBertConfig.
Note:
This class inherits all configuration parameters from ModernBertConfig including
hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, etc.
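    Example:
        A minimal usage sketch. The flat import path is an assumption; adjust it to however this
        module is actually loaded (e.g. via AutoConfig with trust_remote_code):

        ```python
        >>> from configuration_modchembert import ModChemBertConfig

        >>> # Max pool over the last 4 hidden layers, then attend with 8 heads using CLS as query
        >>> config = ModChemBertConfig(
        ...     classifier_pooling="max_seq_mha",
        ...     classifier_pooling_num_attention_heads=8,
        ...     classifier_pooling_last_k=4,
        ... )
        >>> config.classifier_pooling
        'max_seq_mha'
        ```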
"""
    model_type = "modchembert"

    def __init__(
        self,
        *args,
        classifier_pooling: Literal[
            "cls",
            "mean",
            "sum_mean",
            "sum_sum",
            "mean_mean",
            "mean_sum",
            "max_cls",
            "cls_mha",
            "max_seq_mha",
            "max_seq_mean",
        ] = "max_seq_mha",
        classifier_pooling_num_attention_heads: int = 4,
        classifier_pooling_attention_dropout: float = 0.0,
        classifier_pooling_last_k: int = 8,
        **kwargs,
    ):
        # ModernBertConfig only accepts "cls" or "mean" for classifier_pooling, so pass a
        # valid placeholder to avoid a ValueError during its __init__
        super().__init__(*args, classifier_pooling="cls", **kwargs)
        # Restore the custom pooling strategy and its associated hyperparameters
        self.classifier_pooling = classifier_pooling
        self.classifier_pooling_num_attention_heads = classifier_pooling_num_attention_heads
        self.classifier_pooling_attention_dropout = classifier_pooling_attention_dropout
        self.classifier_pooling_last_k = classifier_pooling_last_k
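

# ---------------------------------------------------------------------------
# Illustrative sketches (hypothetical helpers, not part of the upstream
# modeling code): plain-PyTorch versions of two of the pooling strategies
# documented above, to make the options concrete. `hidden_states` is assumed
# to be the per-layer tuple returned by a ModernBERT-style encoder called
# with output_hidden_states=True; the actual ModChemBert pooling
# implementation may differ in detail.
# ---------------------------------------------------------------------------
import torch
from torch import nn


def sum_mean_pool(hidden_states: tuple[torch.Tensor, ...], attention_mask: torch.Tensor) -> torch.Tensor:
    """"sum_mean": sum hidden states across layers, then mask-aware mean pool over the sequence."""
    summed = torch.stack(hidden_states, dim=0).sum(dim=0)  # (batch, seq_len, hidden)
    mask = attention_mask.unsqueeze(-1).to(summed.dtype)  # zero out padding positions
    return (summed * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)


def max_seq_mha_pool(
    hidden_states: tuple[torch.Tensor, ...],
    attention_mask: torch.Tensor,
    mha: nn.MultiheadAttention,
    last_k: int = 8,
) -> torch.Tensor:
    """"max_seq_mha": element-wise max over the last k layers, then MHA with the CLS token as query.

    `mha` is assumed to be built with batch_first=True from the config fields, e.g.:
        nn.MultiheadAttention(
            config.hidden_size,
            config.classifier_pooling_num_attention_heads,
            dropout=config.classifier_pooling_attention_dropout,
            batch_first=True,
        )
    """
    # Element-wise max pooling over the last k hidden layers -> (batch, seq_len, hidden)
    pooled = torch.stack(hidden_states[-last_k:], dim=0).max(dim=0).values
    # CLS token (position 0) attends over the max-pooled sequence;
    # key_padding_mask marks padding positions with True
    query = pooled[:, :1, :]
    attended, _ = mha(query, pooled, pooled, key_padding_mask=(attention_mask == 0))
    return attended.squeeze(1)  # (batch, hidden)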