|
import math |
|
from collections import OrderedDict |
|
from dataclasses import dataclass |
|
from typing import Literal, Optional, Union |
|
|
|
import torch |
|
from torch import nn |
|
from torch.nn.functional import ( |
|
binary_cross_entropy_with_logits, |
|
cross_entropy, |
|
gelu, |
|
mse_loss, |
|
scaled_dot_product_attention, |
|
softmax, |
|
) |
|
from transformers import PreTrainedModel |
|
from transformers.utils import ModelOutput |
|
|
|
from .configuration_bacformer import SPECIAL_TOKENS_DICT, BacformerConfig |
|
from .utils_bacformer import compute_contrastive_loss, create_4d_from_2d_attn_mask, top_k_filtering, top_p_filtering |
|
|
|
|
|
@dataclass |
|
class BacformerModelOutput(ModelOutput): |
|
"""Base class for outputs of the Bacformer model.""" |
|
|
|
    loss: torch.FloatTensor | None = None
    logits: torch.FloatTensor | None = None
    last_hidden_state: torch.FloatTensor | None = None
    attentions: torch.FloatTensor | None = None
    pooler_output: torch.FloatTensor | None = None
|
|
|
|
|
|
|
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): |
|
"""Reshape the rotary embeddings for broadcasting.""" |
|
ndim = x.ndim |
|
assert 0 <= 1 < ndim |
|
assert freqs_cis.shape == (x.shape[1], x.shape[-1]) |
|
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] |
|
return freqs_cis.view(*shape) |
|
|
|
|
|
|
|
def apply_rotary_emb( |
|
xq: torch.Tensor, |
|
xk: torch.Tensor, |
|
freqs_cos: torch.Tensor, |
|
freqs_sin: torch.Tensor, |
|
) -> tuple[torch.Tensor, torch.Tensor]: |
|
"""Apply rotary embeddings to the query and key tensors.""" |
|
|
|
xq_r, xq_i = xq.float().reshape(*xq.shape[:-1], -1, 2).unbind(-1) |
|
xk_r, xk_i = xk.float().reshape(*xk.shape[:-1], -1, 2).unbind(-1) |
|
|
|
|
|
freqs_cos = reshape_for_broadcast(freqs_cos, xq_r) |
|
freqs_sin = reshape_for_broadcast(freqs_sin, xq_r) |
|
|
|
|
|
xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin |
|
xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos |
|
xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin |
|
xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos |
|
|
|
|
|
xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3) |
|
xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3) |
|
|
|
return xq_out.type_as(xq), xk_out.type_as(xk) |
|
|
|
|
|
|
|
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): |
|
"""Precompute the freqs cis for rotary embeddings.""" |
|
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) |
|
t = torch.arange(end, device=freqs.device) |
|
freqs = torch.outer(t, freqs).float() |
|
|
|
freqs_cos = torch.cos(freqs) |
|
freqs_sin = torch.sin(freqs) |
|
return freqs_cos, freqs_sin |
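# Illustrative usage sketch (comments only, not executed at import). Shapes follow from the
# helpers above; the concrete sizes are assumptions for illustration.
#
#   freqs_cos, freqs_sin = precompute_freqs_cis(dim=64, end=128)   # each of shape (128, 32)
#   xq = torch.randn(2, 128, 8, 64)                                # (batch, seq, heads, head_dim)
#   xk = torch.randn(2, 128, 8, 64)
#   xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)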
|
|
|
|
|
def scaled_dot_product_attention_w_attn_weights( |
|
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None |
|
) -> tuple[torch.Tensor, torch.Tensor]: |
|
"""PyTorch Native implementation, modified to return attention weights.""" |
|
L, S = query.size(-2), key.size(-2) |
|
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale |
|
attn_bias = torch.zeros(L, S, dtype=query.dtype).to(query.device) |
|
if is_causal: |
|
assert attn_mask is None |
|
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) |
|
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) |
|
attn_bias.to(query.dtype) |
|
|
|
if attn_mask is not None: |
|
if attn_mask.dtype == torch.bool: |
|
attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf")) |
|
else: |
|
attn_bias += attn_mask |
|
attn_weight = query @ key.transpose(-2, -1) * scale_factor |
|
attn_weight += attn_bias |
|
attn_weight = torch.softmax(attn_weight, dim=-1) |
|
attn_weight = torch.dropout(attn_weight, dropout_p, train=True) |
|
attn_output = attn_weight @ value |
|
return attn_output, attn_weight |
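# Illustrative sketch (comments only): with dropout_p=0.0 the helper above should match the
# native scaled_dot_product_attention up to numerical tolerance, while additionally exposing
# the attention weights. Tensor sizes are assumptions for illustration.
#
#   q = k = v = torch.randn(1, 8, 16, 32)   # (batch, heads, seq, head_dim)
#   out, weights = scaled_dot_product_attention_w_attn_weights(q, k, v, dropout_p=0.0)
#   ref = scaled_dot_product_attention(q, k, v)
#   torch.testing.assert_close(out, ref)    # weights.shape == (1, 8, 16, 16)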
|
|
|
|
|
class RotarySelfAttention(nn.Module): |
|
"""Rotary self-attention module.""" |
|
|
|
def __init__( |
|
self, |
|
embed_dim: int, |
|
num_heads: int, |
|
dropout: float = 0.1, |
|
): |
|
super().__init__() |
|
self.embed_dim = embed_dim |
|
self.num_heads = num_heads |
|
self.dim_head = embed_dim // num_heads |
|
self.dropout_rate = dropout |
|
|
|
self.q = nn.Linear(embed_dim, embed_dim, bias=False) |
|
self.k = nn.Linear(embed_dim, embed_dim, bias=False) |
|
self.v = nn.Linear(embed_dim, embed_dim, bias=False) |
|
self.att_proj_linear = nn.Linear(embed_dim, embed_dim) |
|
|
|
def forward( |
|
self, |
|
x: torch.Tensor, |
|
attn_mask: torch.Tensor, |
|
freqs_cos: torch.Tensor, |
|
freqs_sin: torch.Tensor, |
|
is_causal: bool = False, |
|
return_attn_weights: bool = False, |
|
): |
|
"""Forward pass for the rotary self-attention module.""" |
|
batch_size, seq_len, _ = x.shape |
|
xq, xk, xv = self.q(x), self.k(x), self.v(x) |
|
|
|
xq = xq.view(batch_size, seq_len, self.num_heads, self.dim_head) |
|
xk = xk.view(batch_size, seq_len, self.num_heads, self.dim_head) |
|
xv = xv.view(batch_size, seq_len, self.num_heads, self.dim_head) |
|
xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin) |
|
|
|
|
|
xq = xq.transpose(1, 2) |
|
xk = xk.transpose(1, 2) |
|
xv = xv.transpose(1, 2) |
|
|
|
attn_weights = None |
|
if return_attn_weights: |
|
att, attn_weights = scaled_dot_product_attention_w_attn_weights( |
|
query=xq, |
|
key=xk, |
|
value=xv, |
|
attn_mask=attn_mask, |
|
dropout_p=self.dropout_rate if self.training else 0.0, |
|
is_causal=is_causal, |
|
) |
|
else: |
|
att = scaled_dot_product_attention( |
|
query=xq, |
|
key=xk, |
|
value=xv, |
|
attn_mask=attn_mask, |
|
dropout_p=self.dropout_rate if self.training else 0.0, |
|
is_causal=is_causal, |
|
) |
|
|
|
out = att.transpose(1, 2).contiguous() |
|
out = out.view(batch_size, seq_len, self.num_heads * self.dim_head) |
|
|
|
return self.att_proj_linear(out), attn_weights |
|
|
|
|
|
class BacformerTransformerLayer(nn.Module): |
|
"""Own implementation of transformer layer which uses pytorch native MHA but returns attention weights""" |
|
|
|
def __init__( |
|
self, |
|
hidden_size: int, |
|
intermediate_size: int, |
|
num_attention_heads: int, |
|
dropout: float = 0.1, |
|
activation: Literal["gelu", "relu"] = "gelu", |
|
): |
|
super().__init__() |
|
self.self_mha = RotarySelfAttention( |
|
embed_dim=hidden_size, |
|
num_heads=num_attention_heads, |
|
dropout=dropout, |
|
) |
|
|
|
self.fc1 = nn.Linear(hidden_size, intermediate_size) |
|
self.fc2 = nn.Linear(intermediate_size, hidden_size) |
|
self.activation = nn.GELU() if activation == "gelu" else nn.ReLU() |
|
self.norm1 = nn.LayerNorm(hidden_size) |
|
self.norm2 = nn.LayerNorm(hidden_size) |
|
self.dropout1 = nn.Dropout(dropout) |
|
self.dropout2 = nn.Dropout(dropout) |
|
self.dropout3 = nn.Dropout(dropout) |
|
|
|
def forward( |
|
self, |
|
hidden_state: torch.Tensor, |
|
attention_mask: torch.Tensor = None, |
|
freqs_cos: torch.Tensor = None, |
|
freqs_sin: torch.Tensor = None, |
|
return_attn_weights: bool = False, |
|
is_causal: bool = False, |
|
) -> tuple[torch.Tensor, torch.Tensor | None]: |
|
"""Forward pass""" |
|
attn_outputs, attn_weights = self.self_mha( |
|
hidden_state, |
|
attn_mask=attention_mask, |
|
freqs_cos=freqs_cos, |
|
freqs_sin=freqs_sin, |
|
return_attn_weights=return_attn_weights, |
|
is_causal=is_causal, |
|
) |
|
x = self.norm1(hidden_state + self.dropout1(attn_outputs)) |
|
ff_output = self.fc2(self.dropout2(self.activation(self.fc1(x)))) |
|
x = self.norm2(x + self.dropout3(ff_output)) |
|
return x, attn_weights |
|
|
|
|
|
class BacformerTransformerEncoder(nn.Module): |
|
"""Own implementation of Transformer which return attention weights""" |
|
|
|
def __init__( |
|
self, |
|
num_hidden_layers: int, |
|
hidden_size: int, |
|
intermediate_size: int, |
|
num_attention_heads: int, |
|
dropout: float = 0.1, |
|
activation: Literal["gelu", "relu"] = "gelu", |
|
): |
|
super().__init__() |
|
|
|
self.layers = nn.ModuleList( |
|
[ |
|
BacformerTransformerLayer( |
|
hidden_size=hidden_size, |
|
intermediate_size=intermediate_size, |
|
num_attention_heads=num_attention_heads, |
|
dropout=dropout, |
|
activation=activation, |
|
) |
|
for _ in range(num_hidden_layers) |
|
] |
|
) |
|
self.gradient_checkpointing = False |
|
|
|
def forward( |
|
self, |
|
hidden_state: torch.Tensor, |
|
attention_mask: torch.Tensor = None, |
|
freqs_cos: torch.Tensor = None, |
|
freqs_sin: torch.Tensor = None, |
|
return_attn_weights: bool = False, |
|
is_causal: bool = False, |
|
) -> tuple[torch.Tensor, list[torch.Tensor | None]]: |
|
"""Forward pass""" |
|
attn_weights_arr = [] |
|
for layer in self.layers: |
|
if self.gradient_checkpointing and self.training: |
|
hidden_state, attn_weights = self._gradient_checkpointing_func( |
|
layer.__call__, |
|
hidden_state, |
|
attention_mask, |
|
freqs_cos, |
|
freqs_sin, |
|
return_attn_weights, |
|
is_causal, |
|
) |
|
else: |
|
hidden_state, attn_weights = layer( |
|
hidden_state=hidden_state, |
|
attention_mask=attention_mask, |
|
freqs_cos=freqs_cos, |
|
freqs_sin=freqs_sin, |
|
return_attn_weights=return_attn_weights, |
|
is_causal=is_causal, |
|
) |
|
|
|
attn_weights_arr.append(attn_weights) |
|
return hidden_state, attn_weights_arr |
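# Illustrative sketch (comments only): a tiny two-layer encoder run on random hidden states,
# reusing the rotary tables from precompute_freqs_cis. All sizes are assumptions.
#
#   enc = BacformerTransformerEncoder(num_hidden_layers=2, hidden_size=64,
#                                     intermediate_size=128, num_attention_heads=8)
#   cos, sin = precompute_freqs_cis(64 // 8, 16)
#   out, attn = enc(torch.randn(2, 16, 64), freqs_cos=cos, freqs_sin=sin, return_attn_weights=True)
#   # out.shape == (2, 16, 64); attn is a list with one (2, 8, 16, 16) tensor per layer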
|
|
|
|
|
class BacformerEmbeddings(nn.Module): |
|
"""Construct the protein embeddings from protein sequence, position embeddings and sequence type embeddings.""" |
|
|
|
def __init__(self, config): |
|
super().__init__() |
|
self.config = config |
|
self.linear = nn.Linear(config.hidden_size, config.hidden_size) |
|
|
|
self.token_type_embeddings = nn.Embedding( |
|
num_embeddings=config.max_token_type_embeddings + 1, |
|
embedding_dim=config.hidden_size, |
|
padding_idx=config.max_token_type_embeddings, |
|
) |
|
|
|
self.special_tokens_embeddings = nn.Embedding( |
|
num_embeddings=config.num_special_tokens, |
|
embedding_dim=config.hidden_size, |
|
) |
|
self.prot_emb_token_id = config.prot_emb_token_id |
|
self.pad_token_id = config.pad_token_id |
|
|
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
|
self.dropout = nn.Dropout(config.hidden_dropout_prob) |
|
|
|
def forward( |
|
self, |
|
protein_embeddings: torch.Tensor = None, |
|
special_tokens_mask: torch.Tensor = None, |
|
token_type_ids: torch.Tensor = None, |
|
labels: torch.Tensor = None, |
|
property_ids: torch.Tensor = None, |
|
) -> torch.Tensor: |
|
"""Forward pass for protein embeddings.""" |
|
bs, seq_length, dim = protein_embeddings.shape |
|
|
|
|
|
protein_embeddings = self.linear(protein_embeddings) |
|
        # keep the projected protein embedding at positions marked as protein embeddings;
        # at special-token positions use the learned embedding of that special token instead
        protein_embeddings = torch.where(
            special_tokens_mask.unsqueeze(-1).repeat(1, 1, dim) == self.prot_emb_token_id,
            protein_embeddings,
            self.special_tokens_embeddings(special_tokens_mask),
        )
|
|
|
if token_type_ids is not None: |
|
protein_embeddings += self.token_type_embeddings(token_type_ids) |
|
|
|
protein_embeddings = self.LayerNorm(protein_embeddings) |
|
protein_embeddings = self.dropout(protein_embeddings) |
|
return protein_embeddings |
|
|
|
|
|
class BacformerProteinFamilyEmbeddings(nn.Module): |
|
"""Construct the protein embeddings from protein family tokens, special tokens and sequence type embeddings.""" |
|
|
|
def __init__( |
|
self, |
|
config, |
|
protein_family_embeddings: torch.Tensor = None, |
|
token_type_embeddings: torch.Tensor = None, |
|
special_tokens_embeddings: torch.Tensor = None, |
|
n_conditional_properties: int = None, |
|
): |
|
super().__init__() |
|
self.config = config |
|
|
|
if protein_family_embeddings is not None: |
|
self.protein_family_embeddings = nn.Embedding.from_pretrained( |
|
protein_family_embeddings, |
|
freeze=False, |
|
padding_idx=config.pad_token_id, |
|
) |
|
else: |
|
self.protein_family_embeddings = nn.Embedding( |
|
num_embeddings=config.protein_clusters_vocab_size + 1, |
|
embedding_dim=config.hidden_size, |
|
padding_idx=config.pad_token_id, |
|
) |
|
|
|
if token_type_embeddings is not None: |
|
self.token_type_embeddings = nn.Embedding.from_pretrained( |
|
token_type_embeddings, |
|
freeze=False, |
|
padding_idx=config.max_token_type_embeddings, |
|
) |
|
else: |
|
self.token_type_embeddings = nn.Embedding( |
|
num_embeddings=config.max_token_type_embeddings + 1, |
|
embedding_dim=config.hidden_size, |
|
padding_idx=config.max_token_type_embeddings, |
|
) |
|
|
|
if special_tokens_embeddings is not None: |
|
self.special_tokens_embeddings = nn.Embedding.from_pretrained( |
|
special_tokens_embeddings, |
|
freeze=False, |
|
padding_idx=config.pad_token_id, |
|
) |
|
else: |
|
self.special_tokens_embeddings = nn.Embedding( |
|
num_embeddings=config.num_special_tokens, |
|
embedding_dim=config.hidden_size, |
|
padding_idx=config.pad_token_id, |
|
) |
|
|
|
|
|
if n_conditional_properties is not None: |
|
self.conditional_properties_layer = nn.Embedding(n_conditional_properties, config.hidden_size) |
|
|
|
self.prot_emb_token_id = config.prot_emb_token_id |
|
self.pad_token_id = config.pad_token_id |
|
|
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
|
self.dropout = nn.Dropout(config.hidden_dropout_prob) |
|
|
|
def forward( |
|
self, |
|
protein_embeddings: torch.Tensor = None, |
|
special_tokens_mask: torch.Tensor = None, |
|
token_type_ids: torch.Tensor = None, |
|
labels: torch.Tensor = None, |
|
property_ids: torch.Tensor = None, |
|
) -> torch.Tensor: |
|
"""Forward pass for protein embeddings.""" |
|
|
|
|
|
        # map the ignore index to the padding index before the embedding lookup; clone first
        # so the caller's label tensor (still needed for the loss) is not modified in place
        labels = labels.clone()
        labels[labels == -100] = self.pad_token_id
        protein_embeddings = self.protein_family_embeddings(labels)
|
|
|
bs, seq_length, dim = protein_embeddings.shape |
|
protein_embeddings = torch.where( |
|
special_tokens_mask.unsqueeze(-1).repeat(1, 1, dim) == self.prot_emb_token_id, |
|
protein_embeddings, |
|
self.special_tokens_embeddings(special_tokens_mask), |
|
) |
|
|
|
if token_type_ids is not None: |
|
protein_embeddings += self.token_type_embeddings(token_type_ids) |
|
|
|
if property_ids is not None: |
|
|
|
property_embedding = self.conditional_properties_layer(property_ids).unsqueeze(1) |
|
|
|
|
|
protein_embeddings = torch.cat( |
|
[ |
|
protein_embeddings[:, :1, :], |
|
property_embedding, |
|
protein_embeddings[:, 1:, :], |
|
], |
|
dim=1, |
|
) |
|
|
|
protein_embeddings = self.LayerNorm(protein_embeddings) |
|
protein_embeddings = self.dropout(protein_embeddings) |
|
return protein_embeddings |
|
|
|
|
|
class BacformerEncoder(nn.Module): |
|
"""Bacformer encoder model""" |
|
|
|
def __init__(self, config): |
|
super().__init__() |
|
self.config = config |
|
|
|
self.encoder = BacformerTransformerEncoder( |
|
num_hidden_layers=config.num_hidden_layers, |
|
hidden_size=config.hidden_size, |
|
num_attention_heads=config.num_attention_heads, |
|
intermediate_size=config.intermediate_size, |
|
activation="gelu", |
|
dropout=config.attention_probs_dropout_prob, |
|
) |
|
|
|
|
|
|
|
|
|
freqs_cos, freqs_sin = precompute_freqs_cis( |
|
config.hidden_size // config.num_attention_heads, int(config.max_position_embeddings * 1.5) |
|
) |
|
self.register_buffer("freqs_cos", freqs_cos, persistent=False) |
|
self.register_buffer("freqs_sin", freqs_sin, persistent=False) |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: torch.Tensor = None, |
|
return_attn_weights: Union[bool, None] = None, |
|
is_causal: bool = False, |
|
) -> tuple[torch.Tensor, list[torch.Tensor | None]]: |
|
"""Pass the input through the encoder layers in turn. |
|
|
|
Args: |
|
hidden_states: hidden states from the BacformerEmbeddings layer |
|
attention_mask: mask for the attention in the transformer |
|
""" |
|
return_attn_weights = ( |
|
return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights |
|
) |
|
bs, seq_len, _ = hidden_states.shape |
|
last_hidden_state, attn_weights = self.encoder( |
|
hidden_state=hidden_states, |
|
attention_mask=attention_mask, |
|
freqs_cos=self.freqs_cos[:seq_len, :], |
|
freqs_sin=self.freqs_sin[:seq_len, :], |
|
return_attn_weights=return_attn_weights, |
|
is_causal=is_causal, |
|
) |
|
return last_hidden_state, attn_weights |
|
|
|
|
|
class BacformerPreTrainedModel(PreTrainedModel): |
|
"""An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.""" |
|
|
|
config_class = BacformerConfig |
|
base_model_prefix = "bacformer" |
|
supports_gradient_checkpointing = True |
|
_no_split_modules = ["BacformerEmbeddings", "BacformerTransformerLayer"] |
|
|
|
|
|
def _init_weights(self, module): |
|
"""Initialize the weights""" |
|
if isinstance(module, nn.Linear): |
|
|
|
|
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) |
|
if module.bias is not None: |
|
module.bias.data.zero_() |
|
elif isinstance(module, nn.Embedding): |
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) |
|
if module.padding_idx is not None: |
|
module.weight.data[module.padding_idx].zero_() |
|
elif isinstance(module, nn.LayerNorm): |
|
module.bias.data.zero_() |
|
module.weight.data.fill_(1.0) |
|
|
|
|
|
class BacformerModel(BacformerPreTrainedModel): |
|
"""Bacformer model.""" |
|
|
|
def __init__(self, config: BacformerConfig, add_pooling_layer: bool = False): |
|
super().__init__(config) |
|
self.config = config |
|
|
|
self.embeddings = BacformerEmbeddings(config) |
|
self.encoder = BacformerEncoder(config) |
|
|
|
self.pooler = BacformerPooler(config) if add_pooling_layer else None |
|
|
|
|
|
self.post_init() |
|
|
|
def forward( |
|
self, |
|
protein_embeddings: torch.Tensor = None, |
|
special_tokens_mask: torch.Tensor = None, |
|
token_type_ids: torch.Tensor = None, |
|
attention_mask: torch.Tensor = None, |
|
labels: torch.Tensor = None, |
|
property_ids: torch.Tensor = None, |
|
return_attn_weights: bool = False, |
|
return_dict: Union[bool, None] = None, |
|
is_causal: bool = False, |
|
) -> Optional[BacformerModelOutput]: |
|
"""Forward method for the model.""" |
|
return_dict = return_dict if return_dict is not None else self.config.return_dict |
|
|
|
protein_embeddings = self.embeddings( |
|
protein_embeddings=protein_embeddings, |
|
labels=labels, |
|
special_tokens_mask=special_tokens_mask, |
|
token_type_ids=token_type_ids, |
|
property_ids=property_ids, |
|
) |
|
|
|
|
|
if attention_mask is not None and not is_causal: |
|
attention_mask = create_4d_from_2d_attn_mask( |
|
attn_mask=attention_mask, num_attn_heads=self.config.num_attention_heads |
|
).bool() |
|
|
|
last_hidden_state, attentions = self.encoder( |
|
hidden_states=protein_embeddings, |
|
attention_mask=attention_mask, |
|
return_attn_weights=return_attn_weights, |
|
is_causal=is_causal, |
|
) |
|
pooler_output = ( |
|
self.pooler(hidden_states=last_hidden_state, padding_mask=attention_mask) |
|
if self.pooler is not None |
|
else None |
|
) |
|
|
|
if not return_dict: |
|
return (last_hidden_state, pooler_output, attentions) |
|
|
|
return BacformerModelOutput( |
|
last_hidden_state=last_hidden_state, |
|
pooler_output=pooler_output, |
|
attentions=attentions, |
|
) |
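# Illustrative usage sketch (comments only; assumes BacformerConfig can be instantiated with
# its defaults, otherwise pass hidden_size etc. explicitly). The model consumes pre-computed
# per-protein embeddings plus a special-tokens mask in which positions equal to
# config.prot_emb_token_id mark real protein embeddings.
#
#   config = BacformerConfig()
#   model = BacformerModel(config).eval()
#   prot_emb = torch.randn(1, 10, config.hidden_size)
#   special = torch.full((1, 10), config.prot_emb_token_id, dtype=torch.long)
#   out = model(protein_embeddings=prot_emb, special_tokens_mask=special, return_dict=True)
#   # out.last_hidden_state.shape == (1, 10, config.hidden_size)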
|
|
|
|
|
class BacformerForCausalGM(BacformerPreTrainedModel): |
|
"""Bacformer model with genomic modeling head on top""" |
|
|
|
_tied_weights_keys = ["gm_head.decoder.weight"] |
|
|
|
def __init__(self, config: BacformerConfig): |
|
super().__init__(config) |
|
self.config = config |
|
|
|
self.bacformer = BacformerModel(config, add_pooling_layer=False) |
|
self.gm_head = BacformerGMHead(config) |
|
|
|
|
|
self.init_weights() |
|
|
|
def forward( |
|
self, |
|
protein_embeddings: torch.Tensor, |
|
special_tokens_mask: torch.Tensor, |
|
labels: torch.Tensor = None, |
|
token_type_ids: torch.Tensor = None, |
|
attention_mask: torch.Tensor = None, |
|
return_attn_weights: bool = None, |
|
return_dict: Union[bool, None] = None, |
|
) -> Optional[BacformerModelOutput]: |
|
"""Forward method for the model.""" |
|
return_dict = return_dict if return_dict is not None else self.config.return_dict |
|
return_attn_weights = ( |
|
return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights |
|
) |
|
|
|
outputs = self.bacformer( |
|
protein_embeddings=protein_embeddings, |
|
special_tokens_mask=special_tokens_mask, |
|
token_type_ids=token_type_ids, |
|
attention_mask=None, |
|
return_attn_weights=return_attn_weights, |
|
return_dict=return_dict, |
|
is_causal=True, |
|
) |
|
last_hidden_state = outputs[0] |
|
prediction_scores = self.gm_head(last_hidden_state) |
|
|
|
loss = None |
|
if labels is not None: |
|
labels = labels.to(prediction_scores.device) |
|
|
|
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous().view(-1, prediction_scores.shape[-1]) |
|
labels = labels[:, 1:].contiguous().view(-1) |
|
loss = cross_entropy(shifted_prediction_scores, labels) |
|
|
|
if not return_dict: |
|
return ( |
|
loss, |
|
prediction_scores, |
|
) + outputs |
|
|
|
return BacformerModelOutput( |
|
loss=loss, |
|
logits=prediction_scores, |
|
last_hidden_state=outputs.last_hidden_state, |
|
attentions=outputs.attentions, |
|
) |
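# Sketch of the causal label alignment used above (comments only): the score at position t
# predicts the token at position t + 1, so the last score and the first label are dropped
# before computing the cross-entropy.
#
#   scores = torch.randn(2, 6, 101)                 # (batch, seq, vocab)
#   labels = torch.randint(0, 101, (2, 6))
#   loss = cross_entropy(scores[:, :-1].reshape(-1, 101), labels[:, 1:].reshape(-1))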
|
|
|
|
|
class BacformerForMaskedGM(BacformerPreTrainedModel): |
|
"""Bacformer model with genomic modeling head on top""" |
|
|
|
_tied_weights_keys = ["gm_head.decoder.weight"] |
|
|
|
def __init__(self, config: BacformerConfig): |
|
super().__init__(config) |
|
self.config = config |
|
|
|
self.bacformer = BacformerModel(config, add_pooling_layer=False) |
|
self.gm_head = BacformerGMHead(config) |
|
|
|
|
|
self.init_weights() |
|
|
|
def forward( |
|
self, |
|
protein_embeddings: torch.Tensor, |
|
special_tokens_mask: torch.Tensor, |
|
labels: torch.Tensor = None, |
|
token_type_ids: torch.Tensor = None, |
|
attention_mask: torch.Tensor = None, |
|
return_attn_weights: bool = None, |
|
return_dict: Union[bool, None] = None, |
|
) -> Union[BacformerModelOutput, None]: |
|
"""Forward method for the model.""" |
|
return_dict = return_dict if return_dict is not None else self.config.return_dict |
|
return_attn_weights = ( |
|
return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights |
|
) |
|
|
|
outputs = self.bacformer( |
|
protein_embeddings=protein_embeddings, |
|
special_tokens_mask=special_tokens_mask, |
|
token_type_ids=token_type_ids, |
|
attention_mask=attention_mask, |
|
return_attn_weights=return_attn_weights, |
|
return_dict=return_dict, |
|
) |
|
last_hidden_state = outputs[0] |
|
|
|
|
|
|
|
loss = None |
|
if labels is not None: |
|
|
|
last_hidden_state = last_hidden_state[labels != -100] |
|
prediction_scores = self.gm_head(last_hidden_state) |
|
labels = labels.to(prediction_scores.device) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
labels = labels[labels != -100] |
|
loss = cross_entropy(prediction_scores, labels) |
|
else: |
|
prediction_scores = self.gm_head(last_hidden_state) |
|
|
|
if not return_dict: |
|
return ( |
|
loss, |
|
prediction_scores, |
|
) + outputs |
|
|
|
return BacformerModelOutput( |
|
loss=loss, |
|
logits=prediction_scores, |
|
last_hidden_state=outputs.last_hidden_state, |
|
attentions=outputs.attentions, |
|
) |
|
|
|
|
|
class BacformerForCausalProteinFamilyModeling(BacformerPreTrainedModel): |
|
"""Bacformer model for causal modeling of protein families. Using protein family as tokens rather than protein embeddings""" |
|
|
|
_tied_weights_keys = ["gm_head.decoder.weight"] |
|
|
|
def __init__( |
|
self, |
|
config: BacformerConfig, |
|
n_conditional_properties: int = None, |
|
initialise_from_non_pfm_model: bool = False, |
|
): |
|
super().__init__(config) |
|
self.config = config |
|
self.cls_token_id = SPECIAL_TOKENS_DICT["CLS"] |
|
|
|
self.bacformer = BacformerModel(config, add_pooling_layer=False) |
|
self.gm_head = BacformerGMHead(config) |
|
|
|
if initialise_from_non_pfm_model: |
|
|
|
self.init_weights() |
|
|
|
|
|
self.bacformer.embeddings = BacformerProteinFamilyEmbeddings( |
|
config, |
|
protein_family_embeddings=self.gm_head.decoder.weight, |
|
token_type_embeddings=self.bacformer.embeddings.token_type_embeddings.weight, |
|
special_tokens_embeddings=self.bacformer.embeddings.special_tokens_embeddings.weight, |
|
n_conditional_properties=n_conditional_properties, |
|
) |
|
else: |
|
self.bacformer.embeddings = BacformerProteinFamilyEmbeddings( |
|
config, |
|
n_conditional_properties=n_conditional_properties, |
|
) |
|
self.init_weights() |
|
|
|
def forward( |
|
self, |
|
labels: torch.Tensor = None, |
|
special_tokens_mask: torch.Tensor = None, |
|
token_type_ids: torch.Tensor = None, |
|
property_ids: torch.Tensor = None, |
|
return_attn_weights: bool = None, |
|
return_dict: Union[bool, None] = None, |
|
) -> Optional[BacformerModelOutput]: |
|
"""Forward method for the model.""" |
|
return_dict = return_dict if return_dict is not None else self.config.return_dict |
|
return_attn_weights = ( |
|
return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights |
|
) |
|
|
|
outputs = self.bacformer( |
|
protein_embeddings=None, |
|
labels=labels, |
|
special_tokens_mask=special_tokens_mask, |
|
token_type_ids=token_type_ids, |
|
property_ids=property_ids, |
|
return_attn_weights=return_attn_weights, |
|
return_dict=return_dict, |
|
is_causal=True, |
|
) |
|
last_hidden_state = outputs[0] |
|
prediction_scores = self.gm_head(last_hidden_state) |
|
|
|
loss = None |
|
if labels is not None: |
|
if property_ids is not None: |
|
labels = torch.cat( |
|
[ |
|
torch.tensor([-100], dtype=torch.long) |
|
.unsqueeze(0) |
|
.to(labels.device), |
|
labels, |
|
], |
|
dim=1, |
|
) |
|
labels = labels.to(prediction_scores.device) |
|
|
|
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous().view(-1, prediction_scores.shape[-1]) |
|
labels = labels[:, 1:].contiguous().view(-1) |
|
loss = cross_entropy(shifted_prediction_scores, labels) |
|
|
|
if not return_dict: |
|
return ( |
|
loss, |
|
prediction_scores, |
|
) + outputs |
|
|
|
return BacformerModelOutput( |
|
loss=loss, |
|
logits=prediction_scores, |
|
last_hidden_state=outputs.last_hidden_state, |
|
attentions=outputs.attentions, |
|
) |
|
|
|
def generate( |
|
self, |
|
protein_family_ids: torch.LongTensor, |
|
special_tokens_mask: torch.LongTensor = None, |
|
token_type_ids: torch.LongTensor = None, |
|
max_length: int = 6000, |
|
end_token_id: int = 50000, |
|
do_sample: bool = False, |
|
top_k: int = 50, |
|
top_p: float = 1.0, |
|
temperature: float = 1.0, |
|
property_ids: torch.LongTensor = None, |
|
return_last_hidden_states: bool = False, |
|
): |
|
""" |
|
Generate a sequence of tokens autoregressively from a given prompt. |
|
|
|
Args: |
|
protein_family_ids (torch.LongTensor): Tensor of shape (batch, seq_len) with token indices. |
|
max_length (int): Maximum length of the generated sequence (prompt + newly generated). |
|
end_token_id (int, optional): Token ID signifying end-of-sequence (END). |
|
If encountered, generation stops. |
|
do_sample (bool): Whether to sample from the probability distribution (True) |
|
or use greedy decoding (False). |
|
top_k (int): If >0, use top-k filtering in sampling mode. |
|
top_p (float): If <1.0, use nucleus (top-p) filtering in sampling mode. |
|
temperature (float): Softmax temperature for scaling logits. |
|
Higher => more random, lower => more deterministic. |
|
return_last_hidden_states (bool): If True, return final hidden states as well. |
|
|
|
Returns |
|
------- |
|
torch.LongTensor: The generated token sequence of shape (batch, final_seq_len). |
|
(Optional) torch.FloatTensor: Final hidden states of shape (batch, final_seq_len, hidden_dim) |
|
            if `return_last_hidden_states=True`.
|
""" |
|
|
|
if end_token_id is None: |
|
end_token_id = getattr(self, "end_token_id", None) |
|
|
|
|
|
self.eval() |
|
device = next(self.parameters()).device |
|
protein_family_ids = protein_family_ids.to(device) |
|
|
|
|
|
if special_tokens_mask is None: |
|
|
|
protein_family_ids = torch.cat( |
|
[torch.tensor([[-100]]).to(device), protein_family_ids], |
|
dim=1, |
|
) |
|
            special_tokens_mask = [self.cls_token_id] + [self.config.prot_emb_token_id] * (
                protein_family_ids.shape[1] - 1
            )
            # keep a batch dimension so the mask matches the (batch, seq_len) inputs and can
            # be extended during generation
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.long).unsqueeze(0).to(device)
|
|
|
|
|
if token_type_ids is None: |
|
token_type_ids = torch.zeros_like(protein_family_ids) |
|
|
|
|
|
generated = protein_family_ids.clone() |
|
batch_size, prompt_length = generated.shape |
|
max_new_tokens = max_length - prompt_length |
|
if max_new_tokens <= 0: |
|
max_new_tokens = 0 |
|
|
|
|
|
with torch.no_grad(): |
|
for _step in range(max_new_tokens): |
|
|
|
logits = self.forward( |
|
labels=generated, |
|
special_tokens_mask=special_tokens_mask, |
|
|
|
token_type_ids=token_type_ids, |
|
property_ids=property_ids, |
|
return_dict=True, |
|
).logits |
|
|
|
next_token_logits = logits[:, -1, :] |
|
|
|
|
|
if temperature != 1.0: |
|
next_token_logits = next_token_logits / temperature |
|
|
|
|
|
if do_sample: |
|
|
|
next_token_logits = top_k_filtering(next_token_logits, top_k=top_k) |
|
|
|
next_token_logits = top_p_filtering(next_token_logits, top_p=top_p) |
|
|
|
probs = softmax(next_token_logits, dim=-1) |
|
next_token_id = torch.multinomial(probs, num_samples=1) |
|
else: |
|
|
|
next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True) |
|
|
|
|
|
generated = torch.cat([generated, next_token_id], dim=1) |
|
special_tokens_mask = torch.cat( |
|
[special_tokens_mask, torch.tensor([[self.config.prot_emb_token_id]]).to(generated.device)], dim=1 |
|
) |
|
last_token_type_id = token_type_ids[:, -1].unsqueeze(1) |
|
token_type_ids = torch.cat([token_type_ids, last_token_type_id], dim=1) |
|
|
|
|
|
if end_token_id is not None: |
|
if (next_token_id.squeeze(1) == end_token_id).all(): |
|
|
|
break |
|
|
|
if not return_last_hidden_states: |
|
return generated |
|
|
|
|
|
if return_last_hidden_states: |
|
last_hidden_state = self.forward( |
|
labels=generated, |
|
special_tokens_mask=special_tokens_mask, |
|
token_type_ids=token_type_ids, |
|
return_dict=True, |
|
).last_hidden_state |
|
|
|
return generated, last_hidden_state |
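# Illustrative generation sketch (comments only; the prompt IDs are placeholders). The prompt
# is a (1, seq_len) tensor of protein family IDs; a CLS position (and, optionally, a
# conditional property token) is prepended internally, and generation stops at max_length
# or once end_token_id is produced.
#
#   model = BacformerForCausalProteinFamilyModeling(config)
#   prompt = torch.randint(0, config.protein_clusters_vocab_size, (1, 5))
#   generated = model.generate(prompt, max_length=20, do_sample=True, top_k=50, top_p=0.9)
#   # generated.shape[1] <= 20, including the prepended CLS position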
|
|
|
|
|
class BacformerForMaskedGMWithContrastiveLoss(BacformerPreTrainedModel): |
|
"""Bacformer model with genomic modeling head on top""" |
|
|
|
_tied_weights_keys = ["gm_head.decoder.weight"] |
|
|
|
def __init__(self, config: BacformerConfig): |
|
super().__init__(config) |
|
self.config = config |
|
|
|
self.bacformer = BacformerModel(config, add_pooling_layer=False) |
|
self.gm_head = BacformerGMHead(config) |
|
|
|
|
|
self.init_weights() |
|
|
|
def forward( |
|
self, |
|
protein_embeddings: torch.Tensor, |
|
special_tokens_mask: torch.Tensor, |
|
labels: torch.Tensor = None, |
|
token_type_ids: torch.Tensor = None, |
|
attention_mask: torch.Tensor = None, |
|
return_attn_weights: bool = None, |
|
return_dict: Union[bool, None] = None, |
|
) -> Union[BacformerModelOutput, None]: |
|
"""Forward method for the model.""" |
|
return_dict = return_dict if return_dict is not None else self.config.return_dict |
|
return_attn_weights = ( |
|
return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights |
|
) |
|
|
|
outputs = self.bacformer( |
|
protein_embeddings=protein_embeddings, |
|
special_tokens_mask=special_tokens_mask, |
|
token_type_ids=token_type_ids, |
|
attention_mask=attention_mask, |
|
return_attn_weights=return_attn_weights, |
|
return_dict=return_dict, |
|
) |
|
last_hidden_state = outputs[0] |
|
|
|
|
|
|
|
loss = None |
|
if labels is not None: |
|
|
|
contrastive_loss = compute_contrastive_loss(protein_embeddings, last_hidden_state, special_tokens_mask) |
|
|
|
last_hidden_state = last_hidden_state[labels != -100] |
|
prediction_scores = self.gm_head(last_hidden_state) |
|
labels = labels.to(prediction_scores.device) |
|
|
|
|
|
labels = labels[labels != -100] |
|
masked_loss = cross_entropy(prediction_scores, labels) |
|
loss = masked_loss + self.config.alpha_contrastive_loss * contrastive_loss |
|
else: |
|
prediction_scores = self.gm_head(last_hidden_state) |
|
|
|
if not return_dict: |
|
return ( |
|
loss, |
|
prediction_scores, |
|
) + outputs |
|
|
|
return BacformerModelOutput( |
|
loss=loss, |
|
logits=prediction_scores, |
|
last_hidden_state=outputs.last_hidden_state, |
|
attentions=outputs.attentions, |
|
) |
|
|
|
|
|
class BacformerForProteinClassification(BacformerPreTrainedModel): |
|
"""Bacformer model with a classification head on top for protein classification tasks.""" |
|
|
|
def __init__(self, config: BacformerConfig, benchmark_esm: bool = False): |
|
super().__init__(config) |
|
self.config = config |
|
self.benchmark_esm = benchmark_esm |
|
|
|
self.bacformer = BacformerModel(config, add_pooling_layer=False) |
|
self.dropout = nn.Dropout(config.hidden_dropout_prob) |
|
self.classifier = nn.Linear(config.hidden_size, config.num_labels) |
|
|
|
|
|
self.post_init() |
|
|
|
def forward( |
|
self, |
|
protein_embeddings: torch.Tensor, |
|
special_tokens_mask: torch.Tensor, |
|
labels: torch.Tensor = None, |
|
token_type_ids: torch.Tensor = None, |
|
attention_mask: torch.Tensor = None, |
|
return_attn_weights: bool = None, |
|
return_dict: Union[bool, None] = None, |
|
) -> Optional[BacformerModelOutput]: |
|
"""Forward method for the model.""" |
|
return_dict = return_dict if return_dict is not None else self.config.return_dict |
|
return_attn_weights = ( |
|
return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights |
|
) |
|
|
|
if self.benchmark_esm: |
|
outputs = [protein_embeddings] |
|
else: |
|
outputs = self.bacformer( |
|
protein_embeddings=protein_embeddings, |
|
special_tokens_mask=special_tokens_mask, |
|
token_type_ids=token_type_ids, |
|
attention_mask=attention_mask, |
|
return_attn_weights=return_attn_weights, |
|
return_dict=return_dict, |
|
) |
|
|
|
last_hidden_state = outputs[0] |
|
|
|
last_hidden_state = self.dropout(last_hidden_state) |
|
logits = self.classifier(last_hidden_state) |
|
|
|
loss = None |
|
if labels is not None: |
|
labels = labels.to(logits.device) |
|
|
|
if self.config.problem_type == "regression": |
|
loss = mse_loss(logits, labels) |
|
elif self.config.problem_type == "single_label_classification": |
|
loss = cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1)) |
|
elif ( |
|
self.config.problem_type == "multi_label_classification" |
|
or self.config.problem_type == "binary_classification" |
|
): |
|
|
|
mask = torch.ones_like(labels.view(-1)) - (labels.view(-1) == -100.0).float() |
|
loss = binary_cross_entropy_with_logits( |
|
logits.view(-1), labels.view(-1).type_as(logits), reduction="none" |
|
) |
|
loss = (loss * mask).sum() / mask.sum() |
|
|
|
if not return_dict: |
|
return ( |
|
loss, |
|
None, |
|
logits, |
|
) |
|
|
|
        return BacformerModelOutput(
            loss=loss,
            logits=logits,
            last_hidden_state=last_hidden_state,
            # when benchmarking raw ESM embeddings no Bacformer forward pass is run,
            # so there are no attention outputs to return
            attentions=None if self.benchmark_esm else outputs.attentions,
        )
|
|
|
|
|
class BacformerForGenomeClassification(BacformerPreTrainedModel): |
|
"""Bacformer model with a classification head on top for genome classification tasks.""" |
|
|
|
def __init__(self, config: BacformerConfig): |
|
super().__init__(config) |
|
self.config = config |
|
|
|
self.bacformer = BacformerModel(config, add_pooling_layer=False) |
|
self.classifier = BacformerGenomeClassificationHead(config) |
|
|
|
|
|
self.post_init() |
|
|
|
def forward( |
|
self, |
|
protein_embeddings: torch.Tensor, |
|
special_tokens_mask: torch.Tensor, |
|
labels: torch.Tensor = None, |
|
token_type_ids: torch.Tensor = None, |
|
attention_mask: torch.Tensor = None, |
|
return_attn_weights: bool = None, |
|
return_dict: Union[bool, None] = None, |
|
) -> Optional[BacformerModelOutput]: |
|
"""Forward method for the model.""" |
|
return_dict = return_dict if return_dict is not None else self.config.return_dict |
|
return_attn_weights = ( |
|
return_attn_weights if return_attn_weights is not None else self.config.return_attn_weights |
|
) |
|
|
|
outputs = self.bacformer( |
|
protein_embeddings=protein_embeddings, |
|
special_tokens_mask=special_tokens_mask, |
|
token_type_ids=token_type_ids, |
|
attention_mask=attention_mask, |
|
return_attn_weights=return_attn_weights, |
|
return_dict=return_dict, |
|
) |
|
last_hidden_state = outputs[0] |
|
logits = self.classifier(last_hidden_state, attention_mask) |
|
|
|
loss = None |
|
if labels is not None: |
|
labels = labels.to(logits.device) |
|
|
|
if self.config.problem_type == "regression": |
|
loss = mse_loss(logits.view(-1), labels.view(-1)) |
|
elif self.config.problem_type == "binary_classification": |
|
loss = binary_cross_entropy_with_logits(logits.view(-1), labels.view(-1).type_as(logits)) |
|
elif self.config.problem_type == "single_label_classification": |
|
loss = cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1)) |
|
elif self.config.problem_type == "multi_label_classification": |
|
loss = binary_cross_entropy_with_logits(logits, labels) |
|
|
|
if not return_dict: |
|
return ( |
|
loss, |
|
None, |
|
logits, |
|
) |
|
|
|
return BacformerModelOutput( |
|
loss=loss, |
|
logits=logits, |
|
last_hidden_state=outputs.last_hidden_state, |
|
attentions=outputs.attentions, |
|
) |
|
|
|
|
|
class BacformerForProteinProteinInteraction(BacformerPreTrainedModel): |
|
"""Bacformer model with a protein-protein interaction head on top.""" |
|
|
|
def __init__(self, config: BacformerConfig, benchmark_esm: bool = False): |
|
super().__init__(config) |
|
self.config = config |
|
self.benchmark_esm = benchmark_esm |
|
print("Benchmark ESM:", self.benchmark_esm) |
|
self.return_attn_weights = config.return_attn_weights |
|
|
|
self.bacformer = BacformerModel(config, add_pooling_layer=False) |
|
self.dropout = nn.Dropout(config.hidden_dropout_prob) |
|
self.dense = nn.Sequential( |
|
nn.Linear(config.hidden_size, config.hidden_size), |
|
nn.GELU(), |
|
nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps), |
|
nn.Dropout(0.2), |
|
) |
|
self.ppi_head = BacformerProteinProteinInteractionHead( |
|
in_features=config.hidden_size, prot_emb_idx=config.prot_emb_token_id |
|
) |
|
|
|
|
|
self.post_init() |
|
|
|
def forward( |
|
self, |
|
protein_embeddings: torch.Tensor, |
|
special_tokens_mask: torch.Tensor, |
|
labels: torch.Tensor = None, |
|
token_type_ids: torch.Tensor = None, |
|
attention_mask: torch.Tensor = None, |
|
return_attn_weights: bool = None, |
|
return_dict: Union[bool, None] = None, |
|
) -> Union[OrderedDict, None]: |
|
"""Forward method for the model.""" |
|
return_dict = return_dict if return_dict is not None else self.config.return_dict |
|
|
|
        if self.benchmark_esm:
            # use the raw protein embeddings directly, dropping the first and last two
            # sequence positions to match the Bacformer branch below (assumes batch size 1)
            last_hidden_state = protein_embeddings.squeeze(0)[1:-2, :]
|
else: |
|
outputs = self.bacformer( |
|
protein_embeddings=protein_embeddings, |
|
special_tokens_mask=special_tokens_mask, |
|
token_type_ids=token_type_ids, |
|
attention_mask=attention_mask, |
|
return_attn_weights=False, |
|
return_dict=True, |
|
) |
|
last_hidden_state = outputs.last_hidden_state.squeeze(0)[1:-2, :] |
|
|
|
assert labels.shape[0] == 1, "Batch size should be 1 for protein-protein interaction task" |
|
|
|
last_hidden_state = self.dense(self.dropout(last_hidden_state)) |
|
last_hidden_state = torch.cat([last_hidden_state[labels[:, 0]], last_hidden_state[labels[:, 1]]], dim=0).mean( |
|
dim=0 |
|
) |
|
logits = self.ppi_head(last_hidden_state) |
|
|
|
loss = binary_cross_entropy_with_logits(logits, labels[:, 2].type_as(logits).squeeze(0)) |
|
|
|
if not return_dict: |
|
return ( |
|
loss, |
|
logits, |
|
) |
|
|
|
        return BacformerModelOutput(
            loss=loss,
            logits=logits,
            # when benchmarking raw ESM embeddings no Bacformer forward pass is run,
            # so there are no encoder outputs to return
            last_hidden_state=None if self.benchmark_esm else outputs.last_hidden_state,
            attentions=None if self.benchmark_esm else outputs.attentions,
        )
|
|
|
|
|
|
|
class BacformerPooler(nn.Module): |
|
"""Pooler for Bacformer model.""" |
|
|
|
def __init__(self, config: BacformerConfig): |
|
super().__init__() |
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size) |
|
self.activation = nn.Tanh() |
|
|
|
def forward(self, hidden_states: torch.Tensor, padding_mask: torch.Tensor = None) -> torch.Tensor: |
|
"""Forward method for the pooler.""" |
|
|
|
padding_mask = padding_mask.to(hidden_states.device) if padding_mask is not None else None |
|
if padding_mask is not None: |
|
mean_hidden_states = torch.einsum("ijk,ij->ik", hidden_states, padding_mask) / padding_mask.sum( |
|
1 |
|
).unsqueeze(1) |
|
else: |
|
mean_hidden_states = hidden_states.mean(dim=1) |
|
pooled_output = self.dense(mean_hidden_states) |
|
pooled_output = self.activation(pooled_output) |
|
return pooled_output |
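# Note (comments only): with a 2D (batch, seq_len) padding mask the einsum above is a masked
# mean over the sequence dimension, e.g.
#
#   mask = torch.ones(2, 10)
#   h = torch.randn(2, 10, 64)
#   manual = (h * mask.unsqueeze(-1)).sum(1) / mask.sum(1, keepdim=True)
#   # equals torch.einsum("ijk,ij->ik", h, mask) / mask.sum(1).unsqueeze(1)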
|
|
|
|
|
class BacformerGMHead(nn.Module): |
|
"""Bacformer Head for genomic modeling.""" |
|
|
|
def __init__(self, config): |
|
super().__init__() |
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size) |
|
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
|
|
|
|
|
self.decoder = nn.Linear(config.hidden_size, config.protein_clusters_vocab_size + 1, bias=False) |
|
self.bias = nn.Parameter(torch.zeros(config.protein_clusters_vocab_size + 1)) |
|
|
|
def forward(self, features, **kwargs): |
|
"""Forward method for the head.""" |
|
x = self.dense(features) |
|
x = gelu(x) |
|
x = self.layer_norm(x) |
|
|
|
|
|
x = self.decoder(x) + self.bias |
|
return x |
|
|
|
|
|
class BacformerGenomeClassificationHead(nn.Module): |
|
"""Head for genome-level classification tasks.""" |
|
|
|
def __init__(self, config: BacformerConfig): |
|
super().__init__() |
|
self.dropout = nn.Dropout(config.hidden_dropout_prob) |
|
self.out_proj = nn.Linear(config.hidden_size, config.num_labels) |
|
|
|
def forward(self, features: torch.Tensor, padding_mask: torch.Tensor, **kwargs): |
|
"""Forward method for the head.""" |
|
if padding_mask is not None: |
|
x = torch.einsum("ijk,ij->ik", features, padding_mask) / padding_mask.sum(1).unsqueeze(1) |
|
else: |
|
x = features[:, 0, :] |
|
x = self.dropout(x) |
|
x = self.out_proj(x) |
|
return x |
|
|
|
|
|
class BacformerProteinProteinInteractionHead(nn.Module): |
|
"""Head for protein-protein interaction task at a genome level.""" |
|
|
|
def __init__(self, in_features: int, prot_emb_idx: int = 4, bias: bool = True): |
|
super().__init__() |
|
self.in_features = in_features |
|
self.prot_emb_idx = prot_emb_idx |
|
self.dropout = nn.Dropout(0.2) |
|
self.linear = nn.Linear(in_features, 1, bias=bias) |
|
|
|
def forward( |
|
self, hidden_states: torch.Tensor |
|
) -> torch.Tensor: |
|
"""Forward method for the head.""" |
|
return self.linear(self.dropout(hidden_states)).squeeze(-1) |
|
|