# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .multihead_attention import MultiheadAttention  # noqa
from .axial_attention import ColumnSelfAttention, RowSelfAttention


def gelu(x):
    """Implementation of the gelu activation function.

    For information: OpenAI GPT's gelu is slightly different
    (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
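
# Note (added comment, not part of the upstream docstring): this is the "exact" GELU,
# gelu(x) = x * Phi(x) with Phi the standard normal CDF, e.g. gelu(1.0) is about 0.8413,
# whereas the tanh formula quoted above is OpenAI GPT's approximation.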


def symmetrize(x):
    "Make layer symmetric in final two dimensions, used for contact prediction."
    return x + x.transpose(-1, -2)


def apc(x):
    "Perform average product correction (APC), used for contact prediction."
    a1 = x.sum(-1, keepdim=True)
    a2 = x.sum(-2, keepdim=True)
    a12 = x.sum((-1, -2), keepdim=True)

    avg = a1 * a2
    avg.div_(a12)  # in-place to reduce memory
    normalized = x - avg
    return normalized
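
# For reference (added comment): over the last two dimensions this computes
#     apc(F)[i, j] = F[i, j] - F[i, :].sum() * F[:, j].sum() / F.sum()
# which removes the background coupling signal from the symmetrized attention maps
# before the contact-prediction regression below.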


class ESM1LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12, affine=True):
        """Construct a layernorm layer in the TF style (eps inside the sqrt)."""
        super().__init__()
        self.hidden_size = (hidden_size,) if isinstance(hidden_size, int) else tuple(hidden_size)
        self.eps = eps
        self.affine = bool(affine)
        if self.affine:
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.bias = nn.Parameter(torch.zeros(hidden_size))
        else:
            self.weight, self.bias = None, None

    def forward(self, x):
        dims = tuple(-(i + 1) for i in range(len(self.hidden_size)))
        means = x.mean(dims, keepdim=True)
        x_zeromean = x - means
        variances = x_zeromean.pow(2).mean(dims, keepdim=True)
        x = x_zeromean / torch.sqrt(variances + self.eps)
        if self.affine:
            x = (self.weight * x) + self.bias
        return x
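
# Usage sketch (added comment; shapes are illustrative, not part of the module):
#
#     >>> ln = ESM1LayerNorm(768)                 # or ESM1LayerNorm((768,))
#     >>> y = ln(torch.randn(2, 10, 768))         # normalized over the trailing hidden dim(s)
#     >>> y.shape
#     torch.Size([2, 10, 768])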


try:
    from apex.normalization import FusedLayerNorm as _FusedLayerNorm

    class ESM1bLayerNorm(_FusedLayerNorm):
        def forward(self, x):
            if not x.is_cuda:
                return super().forward(x)
            else:
                with torch.cuda.device(x.device):
                    return super().forward(x)

except ImportError:
    from torch.nn import LayerNorm as ESM1bLayerNorm


class TransformerLayer(nn.Module):
    """Transformer layer block."""

    def __init__(
        self,
        embed_dim,
        ffn_embed_dim,
        attention_heads,
        add_bias_kv=True,
        use_esm1b_layer_norm=False,
        use_rotary_embeddings: bool = False,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.ffn_embed_dim = ffn_embed_dim
        self.attention_heads = attention_heads
        self.use_rotary_embeddings = use_rotary_embeddings
        self._init_submodules(add_bias_kv, use_esm1b_layer_norm)

    def _init_submodules(self, add_bias_kv, use_esm1b_layer_norm):
        BertLayerNorm = ESM1bLayerNorm if use_esm1b_layer_norm else ESM1LayerNorm

        self.self_attn = MultiheadAttention(
            self.embed_dim,
            self.attention_heads,
            add_bias_kv=add_bias_kv,
            add_zero_attn=False,
            use_rotary_embeddings=self.use_rotary_embeddings,
        )
        self.self_attn_layer_norm = BertLayerNorm(self.embed_dim)

        self.fc1 = nn.Linear(self.embed_dim, self.ffn_embed_dim)
        self.fc2 = nn.Linear(self.ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = BertLayerNorm(self.embed_dim)

    def forward(
        self, x, self_attn_mask=None, self_attn_padding_mask=None, need_head_weights=False
    ):
        residual = x
        x = self.self_attn_layer_norm(x)
        x, attn = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            need_weights=True,
            need_head_weights=need_head_weights,
            attn_mask=self_attn_mask,
        )
        x = residual + x

        residual = x
        x = self.final_layer_norm(x)
        x = gelu(self.fc1(x))
        x = self.fc2(x)
        x = residual + x

        return x, attn
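
# Usage sketch (added comment; the (seq_len, batch, embed_dim) layout follows the
# fairseq-style MultiheadAttention used above and is an assumption here):
#
#     >>> layer = TransformerLayer(embed_dim=768, ffn_embed_dim=3072, attention_heads=12)
#     >>> x = torch.randn(128, 2, 768)            # (seq_len, batch, embed_dim)
#     >>> x, attn = layer(x)                      # pre-LN self-attention + FFN, both residual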


class AxialTransformerLayer(nn.Module):
    """Implements an Axial MSA Transformer block."""

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        max_tokens_per_msa: int = 2**14,
    ) -> None:
        super().__init__()

        # Initialize parameters
        self.embedding_dim = embedding_dim
        self.dropout_prob = dropout

        row_self_attention = RowSelfAttention(
            embedding_dim,
            num_attention_heads,
            dropout=dropout,
            max_tokens_per_msa=max_tokens_per_msa,
        )

        column_self_attention = ColumnSelfAttention(
            embedding_dim,
            num_attention_heads,
            dropout=dropout,
            max_tokens_per_msa=max_tokens_per_msa,
        )

        feed_forward_layer = FeedForwardNetwork(
            embedding_dim,
            ffn_embedding_dim,
            activation_dropout=activation_dropout,
            max_tokens_per_msa=max_tokens_per_msa,
        )

        self.row_self_attention = self.build_residual(row_self_attention)
        self.column_self_attention = self.build_residual(column_self_attention)
        self.feed_forward_layer = self.build_residual(feed_forward_layer)

    def build_residual(self, layer: nn.Module):
        return NormalizedResidualBlock(
            layer,
            self.embedding_dim,
            self.dropout_prob,
        )

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_head_weights: bool = False,
    ):
        """
        Applies row self-attention, column self-attention, and the feed-forward
        network in sequence, each wrapped in a pre-LayerNorm residual block
        (see NormalizedResidualBlock).
        """
        x, row_attn = self.row_self_attention(
            x,
            self_attn_mask=self_attn_mask,
            self_attn_padding_mask=self_attn_padding_mask,
        )
        x, column_attn = self.column_self_attention(
            x,
            self_attn_mask=self_attn_mask,
            self_attn_padding_mask=self_attn_padding_mask,
        )
        x = self.feed_forward_layer(x)
        if need_head_weights:
            return x, column_attn, row_attn
        else:
            return x
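
# Usage sketch (added comment; the (num_rows, num_cols, batch, embed_dim) layout is an
# assumption based on the Row/ColumnSelfAttention modules and is illustrative only):
#
#     >>> block = AxialTransformerLayer(embedding_dim=768, num_attention_heads=12)
#     >>> msa = torch.randn(64, 256, 1, 768)      # (num_rows, num_cols, batch, embed_dim)
#     >>> out = block(msa)                        # row attention, column attention, then FFN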


class LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    Padding ids are ignored by either offsetting based on padding_idx
    or by setting padding_idx to None and ensuring that the appropriate
    position ids are passed to the forward function.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
        if padding_idx is not None:
            num_embeddings_ = num_embeddings + padding_idx + 1
        else:
            num_embeddings_ = num_embeddings
        super().__init__(num_embeddings_, embedding_dim, padding_idx)
        self.max_positions = num_embeddings

    def forward(self, input: torch.Tensor):
        """Input is expected to be of size [bsz x seqlen]."""
        if input.size(1) > self.max_positions:
            raise ValueError(
                f"Sequence length {input.size(1)} above maximum "
                f"sequence length of {self.max_positions}"
            )
        mask = input.ne(self.padding_idx).int()
        positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + self.padding_idx
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
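
# Worked example (added comment; token ids are made up for illustration):
# with padding_idx = 1 and input [[5, 6, 1]], mask = [1, 1, 0], so
# positions = cumsum(mask) * mask + padding_idx = [2, 3, 1]: real tokens get
# positions starting at padding_idx + 1, and padding positions map back to the
# padding_idx row of the (num_embeddings + padding_idx + 1)-sized table.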


class SinusoidalPositionalEmbedding(nn.Module):
    def __init__(self, embed_dim, padding_idx, learned=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.padding_idx = padding_idx
        self.register_buffer("_float_tensor", torch.FloatTensor(1))
        self.weights = None

    def forward(self, x):
        bsz, seq_len = x.shape
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            self.weights = self.get_embedding(max_pos)
        self.weights = self.weights.type_as(self._float_tensor)

        positions = self.make_positions(x)
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    def make_positions(self, x):
        mask = x.ne(self.padding_idx)
        range_buf = torch.arange(x.size(1), device=x.device).expand_as(x) + self.padding_idx + 1
        positions = range_buf.expand_as(x)
        return positions * mask.long() + self.padding_idx * (1 - mask.long())

    def get_embedding(self, num_embeddings):
        half_dim = self.embed_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if self.embed_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if self.padding_idx is not None:
            emb[self.padding_idx, :] = 0
        return emb
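
# Shape note (added comment): get_embedding(n) builds an (n, embed_dim) table whose
# first half_dim columns are sin terms and next half_dim are cos terms, with
# frequencies 10000 ** (-i / (half_dim - 1)); forward() then indexes it with the same
# padding-aware positions scheme as LearnedPositionalEmbedding and returns a detached
# tensor of shape (bsz, seq_len, embed_dim).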


class RobertaLMHead(nn.Module):
    """Head for masked language modeling."""

    def __init__(self, embed_dim, output_dim, weight):
        super().__init__()
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = ESM1bLayerNorm(embed_dim)
        self.weight = weight
        self.bias = nn.Parameter(torch.zeros(output_dim))

    def forward(self, features):
        x = self.dense(features)
        x = gelu(x)
        x = self.layer_norm(x)
        # project back to size of vocabulary with bias
        x = F.linear(x, self.weight) + self.bias
        return x
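
# Usage sketch (added comment; `weight` is typically a tied token-embedding matrix of
# shape (output_dim, embed_dim) -- an assumption about how callers wire this up):
#
#     >>> embed_tokens = nn.Embedding(33, 768)
#     >>> lm_head = RobertaLMHead(embed_dim=768, output_dim=33, weight=embed_tokens.weight)
#     >>> logits = lm_head(torch.randn(2, 128, 768))   # -> (2, 128, 33)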


class ContactPredictionHead(nn.Module):
    """Performs symmetrization, APC (average product correction), and computes
    a logistic regression on the output features."""

    def __init__(
        self,
        in_features: int,
        prepend_bos: bool,
        append_eos: bool,
        bias=True,
        eos_idx: Optional[int] = None,
    ):
        super().__init__()
        self.in_features = in_features
        self.prepend_bos = prepend_bos
        self.append_eos = append_eos
        if append_eos and eos_idx is None:
            raise ValueError("Using an alphabet with eos token, but no eos token was passed in.")
        self.eos_idx = eos_idx
        self.regression = nn.Linear(in_features, 1, bias)
        self.activation = nn.Sigmoid()

    def forward(self, tokens, attentions):
        # remove eos token attentions
        if self.append_eos:
            eos_mask = tokens.ne(self.eos_idx).to(attentions)
            eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
            attentions = attentions * eos_mask[:, None, None, :, :]
            attentions = attentions[..., :-1, :-1]
        # remove cls token attentions
        if self.prepend_bos:
            attentions = attentions[..., 1:, 1:]
        batch_size, layers, heads, seqlen, _ = attentions.size()
        attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)

        # features: B x C x T x T
        attentions = attentions.to(
            self.regression.weight.device
        )  # attentions always float32, may need to convert to float16
        attentions = apc(symmetrize(attentions))
        attentions = attentions.permute(0, 2, 3, 1)
        return self.activation(self.regression(attentions).squeeze(3))
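
# Usage sketch (added comment; shapes follow the forward() code above, while the token
# ids, eos index, and layer/head counts are illustrative):
#
#     >>> head = ContactPredictionHead(in_features=12 * 12, prepend_bos=True,
#     ...                              append_eos=True, eos_idx=2)
#     >>> tokens = torch.randint(4, 30, (1, 130))        # (batch, seq_len) incl. bos/eos
#     >>> attn = torch.rand(1, 12, 12, 130, 130)         # (batch, layers, heads, T, T)
#     >>> contacts = head(tokens, attn)                  # -> (1, 128, 128), values in [0, 1]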


class NormalizedResidualBlock(nn.Module):
    def __init__(
        self,
        layer: nn.Module,
        embedding_dim: int,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim

        self.layer = layer
        self.dropout_module = nn.Dropout(
            dropout,
        )
        self.layer_norm = ESM1bLayerNorm(self.embedding_dim)

    def forward(self, x, *args, **kwargs):
        residual = x
        x = self.layer_norm(x)
        outputs = self.layer(x, *args, **kwargs)
        if isinstance(outputs, tuple):
            x, *out = outputs
        else:
            x = outputs
            out = None

        x = self.dropout_module(x)
        x = residual + x

        if out is not None:
            return (x,) + tuple(out)
        else:
            return x
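
# Usage sketch (added comment; wraps an arbitrary sub-layer in pre-LayerNorm,
# dropout, and a residual connection):
#
#     >>> ffn = FeedForwardNetwork(embedding_dim=768, ffn_embedding_dim=3072)
#     >>> block = NormalizedResidualBlock(ffn, embedding_dim=768, dropout=0.1)
#     >>> y = block(torch.randn(4, 16, 768))      # same shape as the input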


class FeedForwardNetwork(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        ffn_embedding_dim: int,
        activation_dropout: float = 0.1,
        max_tokens_per_msa: int = 2**14,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.ffn_embedding_dim = ffn_embedding_dim
        self.max_tokens_per_msa = max_tokens_per_msa
        self.activation_fn = nn.GELU()
        self.activation_dropout_module = nn.Dropout(
            activation_dropout,
        )
        self.fc1 = nn.Linear(embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, embedding_dim)

    def forward(self, x):
        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        return x