|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Literal |
|
|
|
from transformers.models.modernbert.configuration_modernbert import ModernBertConfig |
|
|
|
|
|
class ModChemBertConfig(ModernBertConfig):
    """
    Configuration class for ModChemBert models.

    This configuration class extends ModernBertConfig with additional parameters specific to
    chemical molecule modeling and custom pooling strategies for classification/regression tasks.
    It accepts all arguments and keyword arguments from ModernBertConfig.

    Args:
        classifier_pooling (str, optional): Pooling strategy for sequence classification.
            Available options:
            - "cls": Use CLS token representation
            - "mean": Attention-weighted average pooling
            - "sum_mean": Sum all hidden states across layers, then mean pool over sequence (ChemLM approach)
            - "sum_sum": Sum all hidden states across layers, then sum pool over sequence
            - "mean_mean": Mean all hidden states across layers, then mean pool over sequence
            - "mean_sum": Mean all hidden states across layers, then sum pool over sequence
            - "max_cls": Element-wise max pooling over last k hidden states, then take CLS token
            - "cls_mha": Multi-head attention with CLS token as query and full sequence as keys/values
            - "max_seq_mha": Max pooling over last k states + multi-head attention with CLS as query
            - "max_seq_mean": Max pooling over last k hidden states, then mean pooling over sequence
            Defaults to "max_seq_mha".
        classifier_pooling_num_attention_heads (int, optional): Number of attention heads for multi-head attention
            pooling strategies (cls_mha, max_seq_mha). Defaults to 4.
        classifier_pooling_attention_dropout (float, optional): Dropout probability for multi-head attention
            pooling strategies (cls_mha, max_seq_mha). Defaults to 0.0.
        classifier_pooling_last_k (int, optional): Number of last hidden layers to use for max pooling
            strategies (max_cls, max_seq_mha, max_seq_mean). Defaults to 8.
        *args: Variable length argument list passed to ModernBertConfig.
        **kwargs: Arbitrary keyword arguments passed to ModernBertConfig.

    Note:
        This class inherits all configuration parameters from ModernBertConfig including
        hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, etc.
    """

    model_type = "modchembert"

    def __init__(
        self,
        *args,
        classifier_pooling: Literal[
            "cls",
            "mean",
            "sum_mean",
            "sum_sum",
            "mean_mean",
            "mean_sum",
            "max_cls",
            "cls_mha",
            "max_seq_mha",
            "max_seq_mean",
        ] = "max_seq_mha",
        classifier_pooling_num_attention_heads: int = 4,
        classifier_pooling_attention_dropout: float = 0.0,
        classifier_pooling_last_k: int = 8,
        **kwargs,
    ):
        # Pass a placeholder "cls" to the parent constructor: ModernBertConfig only
        # accepts its own restricted set of pooling values (presumably validates
        # against {"cls", "mean"} — confirm against the installed transformers
        # version), so the extended ModChemBert options would be rejected there.
        # The real value is assigned below, overwriting the placeholder.
        super().__init__(*args, classifier_pooling="cls", **kwargs)

        self.classifier_pooling = classifier_pooling
        self.classifier_pooling_num_attention_heads = classifier_pooling_num_attention_heads
        self.classifier_pooling_attention_dropout = classifier_pooling_attention_dropout
        self.classifier_pooling_last_k = classifier_pooling_last_k
|
|