File size: 2,860 Bytes
7a26e46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from typing import Literal

from transformers import PretrainedConfig

# Special-token name -> integer id. Each id is the token's position in the
# tuple below (PAD=0, MASK=1, CLS=2, SEP=3, PROT_EMB=4, END=5), so entries
# must never be reordered; append new tokens at the end only.
SPECIAL_TOKENS_DICT = {
    token_name: token_id
    for token_id, token_name in enumerate(
        ("PAD", "MASK", "CLS", "SEP", "PROT_EMB", "END")
    )
}


class BacformerConfig(PretrainedConfig):
    """Configuration class to store the configuration of a `BacformerModel`.

    Every named argument is stored on the instance under the same attribute
    name. Any remaining keyword arguments are forwarded unchanged to
    `PretrainedConfig.__init__`.

    Args:
        num_hidden_layers: Number of hidden layers.
        num_attention_heads: Number of attention heads.
        hidden_size: Hidden dimension. The default (480) is the
            esm2_t12_35M_UR50D embedding dim.
        intermediate_size: Size of the intermediate layer.
        hidden_dropout_prob: Dropout probability for hidden states.
        attention_probs_dropout_prob: Dropout probability for attention
            probabilities.
        max_position_embeddings: Maximum number of position embeddings.
        max_token_type_embeddings: Maximum number of token type embeddings.
        layer_norm_eps: Epsilon used by layer normalisation.
        initializer_range: Range used for weight initialisation.
        pad_token_id: Id of the PAD special token.
        mask_token_id: Id of the MASK special token.
        prot_emb_token_id: Id of the PROT_EMB special token.
        end_token_id: Id of the END special token.
        num_special_tokens: Number of special tokens; defaults to the size
            of `SPECIAL_TOKENS_DICT`.
        protein_clusters_vocab_size: Equal to the number of protein clusters
            + 1.
        num_labels: Number of labels for downstream tasks. Ignored when an
            `id2label` mapping is supplied through ``kwargs`` (the mapping
            then defines the number of labels).
        is_causal_gm: Flag stored for the model code; presumably selects
            causal genome modelling (not verifiable from this file).
        return_dict: Stored as `return_dict` on the config.
        return_attn_weights: Stored as `return_attn_weights` on the config.
        alpha_contrastive_loss: Stored as `alpha_contrastive_loss`;
            presumably a weight for a contrastive loss term (confirm against
            the model code).
        problem_type: Only used in `BacformerForGenomeClassification`.
        **kwargs: Forwarded to `PretrainedConfig.__init__`.
    """

    model_type = "bacformer"

    def __init__(
        self,
        num_hidden_layers: int = 6,
        num_attention_heads: int = 8,
        hidden_size: int = 480,  # default esm2_t12_35M_UR50D embedding dim
        intermediate_size: int = 1280,
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 6000,
        max_token_type_embeddings: int = 1000,
        layer_norm_eps: float = 1e-12,
        initializer_range: float = 0.02,
        pad_token_id: int = SPECIAL_TOKENS_DICT["PAD"],
        mask_token_id: int = SPECIAL_TOKENS_DICT["MASK"],
        prot_emb_token_id: int = SPECIAL_TOKENS_DICT["PROT_EMB"],
        end_token_id: int = SPECIAL_TOKENS_DICT["END"],
        num_special_tokens: int = len(SPECIAL_TOKENS_DICT),
        protein_clusters_vocab_size: int = 50001,  # equal to the nr of protein clusters + 1
        num_labels: int = 1,  # for downstream tasks
        is_causal_gm: bool = False,
        return_dict: bool = False,
        return_attn_weights: bool = False,
        alpha_contrastive_loss: float = 0.5,
        # only to use in the BacformerForGenomeClassification
        problem_type: Literal[
            "regression", "binary_classification", "single_label_classification", "multi_label_classification"
        ] = "single_label_classification",
        **kwargs,
    ):
        # `PretrainedConfig.num_labels` is a property derived from `id2label`;
        # assigning it regenerates `id2label`/`label2id` whenever the sizes
        # disagree. A saved config.json carries `id2label` but not
        # `num_labels`, so blindly applying the default `num_labels=1` on
        # reload would wipe a fine-tuned label map. Remember whether the
        # caller supplied a label map so that it can take precedence.
        has_label_map = kwargs.get("id2label") is not None

        super().__init__(**kwargs)

        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.max_token_type_embeddings = max_token_type_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        self.pad_token_id = pad_token_id
        self.mask_token_id = mask_token_id
        self.prot_emb_token_id = prot_emb_token_id
        self.end_token_id = end_token_id
        self.num_special_tokens = num_special_tokens
        self.protein_clusters_vocab_size = protein_clusters_vocab_size
        if not has_label_map:
            # No explicit label map: let `num_labels` define the default
            # LABEL_i mapping, exactly as before.
            self.num_labels = num_labels
        self.is_causal_gm = is_causal_gm
        self.return_dict = return_dict
        self.return_attn_weights = return_attn_weights
        # Assigned after super().__init__ (rather than forwarded to it) so the
        # extra "binary_classification" value is not seen by the base class.
        self.problem_type = problem_type
        self.alpha_contrastive_loss = alpha_contrastive_loss