import logging
import os
import sys

sys.path.append("../")

import torch
import torch.nn as nn
from torch.cuda.amp import autocast
from torch.nn import Module, MultiheadAttention
from torch.nn import functional as F
from tqdm import tqdm

from pytorch_metric_learning import losses, miners
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.reducers import ThresholdReducer
from pytorch_metric_learning.regularizers import LpRegularizer

from transformers import BertModel, EsmConfig, EsmModel

# AdapterConfig is needed by the init_adapters() methods below. It is provided by the
# adapter-transformers fork of transformers (assumption based on the adapter API used here);
# the import is guarded so the rest of the module still loads without that dependency.
try:
    from transformers.adapters import AdapterConfig
except ImportError:  # adapter-transformers not installed
    AdapterConfig = None

from utils.gd_model import GDANet

LOGGER = logging.getLogger(__name__)


class FusionModule(nn.Module):
    def __init__(self, out_dim, num_head, dropout=0.1):
        """FusionModule.

        Args:
            out_dim: model output dimension.
            num_head: number of heads for the multi-head attention (e.g. 8).
            dropout: attention dropout rate. Defaults to 0.1.
        """
        super(FusionModule, self).__init__()
        self.out_dim = out_dim
        self.num_head = num_head
        self.WqS = nn.Linear(out_dim, out_dim)
        self.WkS = nn.Linear(out_dim, out_dim)
        self.WvS = nn.Linear(out_dim, out_dim)
        self.WqT = nn.Linear(out_dim, out_dim)
        self.WkT = nn.Linear(out_dim, out_dim)
        self.WvT = nn.Linear(out_dim, out_dim)
        self.multi_head_attention = nn.MultiheadAttention(out_dim, num_head, dropout=dropout)

    def forward(self, zs, zt):
        # nn.MultiheadAttention expects inputs of shape (token_length, batch_size, out_dim).
        # zs = protein_representation.permute(1, 0, 2)
        # zt = disease_representation.permute(1, 0, 2)
        # Compute query, key and value representations for each modality.
        qs = self.WqS(zs)
        ks = self.WkS(zs)
        vs = self.WvS(zs)
        qt = self.WqT(zt)
        kt = self.WkT(zt)
        vt = self.WvT(zt)
        # self.multi_head_attention() returns the attended representation and the attention
        # weight matrix; only the representation is needed here, so the weights are discarded
        # with "_".
        zs_attention1, _ = self.multi_head_attention(qs, ks, vs)
        zs_attention2, _ = self.multi_head_attention(qs, kt, vt)
        zt_attention1, _ = self.multi_head_attention(qt, kt, vt)
        zt_attention2, _ = self.multi_head_attention(qt, ks, vs)
        # Average the self-attended and cross-attended representations.
        protein_fused = 0.5 * (zs_attention1 + zs_attention2)
        dis_fused = 0.5 * (zt_attention1 + zt_attention2)
        return protein_fused, dis_fused
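
# Minimal usage sketch for FusionModule (illustrative only; the tensor sizes are assumptions,
# not values taken from the training pipeline). nn.MultiheadAttention without batch_first
# expects (token_length, batch_size, out_dim) inputs, as noted in forward().
#
#   fusion = FusionModule(out_dim=1024, num_head=8)
#   zs = torch.randn(200, 4, 1024)   # protein tokens: (token_length, batch_size, out_dim)
#   zt = torch.randn(50, 4, 1024)    # disease tokens: (token_length, batch_size, out_dim)
#   prot_fused, dis_fused = fusion(zs, zt)
#   # prot_fused: (200, 4, 1024), dis_fused: (50, 4, 1024)
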
class CrossAttentionBlock(nn.Module):
    def __init__(self, hidden_dim, num_heads):
        super(CrossAttentionBlock, self).__init__()
        if hidden_dim % num_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_dim, num_heads))
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_size = hidden_dim // num_heads
        self.query1 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key1 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value1 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.query2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.key2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.value2 = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def _alpha_from_logits(self, logits, mask_row, mask_col, inf=1e6):
        # logits: (N, L1, L2, H); mask_row: (N, L1); mask_col: (N, L2).
        # Tokenizer attention masks may arrive as int tensors, so cast to float for the
        # outer product and to bool before torch.where (assumption about the callers).
        N, L1, L2, H = logits.shape
        mask_row = mask_row.float().view(N, L1, 1).repeat(1, 1, H)
        mask_col = mask_col.float().view(N, L2, 1).repeat(1, 1, H)
        mask_pair = torch.einsum('blh, bkh->blkh', mask_row, mask_col).bool()
        # Push masked-out key positions towards -inf before the softmax over the key axis.
        logits = torch.where(mask_pair, logits, logits - inf)
        alpha = torch.softmax(logits, dim=2)
        # Zero out attention rows that correspond to padded query positions.
        mask_row = mask_row.view(N, L1, 1, H).repeat(1, 1, L2, 1).bool()
        alpha = torch.where(mask_row, alpha, torch.zeros_like(alpha))
        return alpha

    def _heads(self, x, n_heads, n_ch):
        # (..., hidden_dim) -> (..., n_heads, n_ch)
        s = list(x.size())[:-1] + [n_heads, n_ch]
        return x.view(*s)

    def forward(self, input1, input2, mask1, mask2):
        query1 = self._heads(self.query1(input1), self.num_heads, self.head_size)
        key1 = self._heads(self.key1(input1), self.num_heads, self.head_size)
        query2 = self._heads(self.query2(input2), self.num_heads, self.head_size)
        key2 = self._heads(self.key2(input2), self.num_heads, self.head_size)
        logits11 = torch.einsum('blhd, bkhd->blkh', query1, key1)
        logits12 = torch.einsum('blhd, bkhd->blkh', query1, key2)
        logits21 = torch.einsum('blhd, bkhd->blkh', query2, key1)
        logits22 = torch.einsum('blhd, bkhd->blkh', query2, key2)
        alpha11 = self._alpha_from_logits(logits11, mask1, mask1)
        alpha12 = self._alpha_from_logits(logits12, mask1, mask2)
        alpha21 = self._alpha_from_logits(logits21, mask2, mask1)
        alpha22 = self._alpha_from_logits(logits22, mask2, mask2)
        value1 = self._heads(self.value1(input1), self.num_heads, self.head_size)
        value2 = self._heads(self.value2(input2), self.num_heads, self.head_size)
        output1 = (torch.einsum('blkh, bkhd->blhd', alpha11, value1).flatten(-2) +
                   torch.einsum('blkh, bkhd->blhd', alpha12, value2).flatten(-2)) / 2
        output2 = (torch.einsum('blkh, bkhd->blhd', alpha21, value1).flatten(-2) +
                   torch.einsum('blkh, bkhd->blhd', alpha22, value2).flatten(-2)) / 2
        return output1, output2
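
# Minimal usage sketch for CrossAttentionBlock (illustrative only; batch size and sequence
# lengths are assumptions). Inputs are batch-first token embeddings plus 0/1 attention masks.
#
#   block = CrossAttentionBlock(hidden_dim=1024, num_heads=8)
#   prot = torch.randn(4, 200, 1024)   # (batch, prot_len, hidden)
#   dis = torch.randn(4, 50, 1024)     # (batch, dis_len, hidden)
#   prot_mask = torch.ones(4, 200)     # 1 = real token, 0 = padding
#   dis_mask = torch.ones(4, 50)
#   prot_out, dis_out = block(prot, dis, prot_mask, dis_mask)
#   # prot_out: (4, 200, 1024), dis_out: (4, 50, 1024)
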
class GDA_Metric_Learning(GDANet):
    def __init__(
        self, prot_encoder, disease_encoder, prot_out_dim, disease_out_dim, args
    ):
        """Constructor for the model.

        Args:
            prot_encoder: protein encoder.
            disease_encoder: disease textual encoder.
            prot_out_dim: output dimension of the protein encoder.
            disease_out_dim: output dimension of the disease encoder.
            args: training arguments; the fields used here are loss, use_miner,
                miner_margin, agg_mode and dropout.
        """
        super(GDA_Metric_Learning, self).__init__(
            prot_encoder,
            disease_encoder,
        )
        self.prot_encoder = prot_encoder
        self.disease_encoder = disease_encoder
        self.loss = args.loss
        self.use_miner = args.use_miner
        self.miner_margin = args.miner_margin
        self.agg_mode = args.agg_mode
        self.prot_reg = nn.Linear(prot_out_dim, 1024)
        # self.prot_reg = nn.Linear(prot_out_dim, disease_out_dim)
        self.dis_reg = nn.Linear(disease_out_dim, 1024)
        # self.prot_adapter_name = None
        # self.disease_adapter_name = None
        self.fusion_layer = FusionModule(1024, num_head=8)
        self.cross_attention_layer = CrossAttentionBlock(1024, 8)
        # # MMP Prediction Heads
        # self.prot_pred_head = nn.Sequential(
        #     nn.Linear(disease_out_dim, disease_out_dim),
        #     nn.ReLU(),
        #     nn.Linear(disease_out_dim, 1280)  # vocabulary size: prot model tokenizer length 30 446
        # )
        # self.dise_pred_head = nn.Sequential(
        #     nn.Linear(disease_out_dim, disease_out_dim),
        #     nn.ReLU(),
        #     nn.Linear(disease_out_dim, 768)  # vocabulary size: disease model tokenizer length 30 522
        # )
        if self.use_miner:
            self.miner = miners.TripletMarginMiner(
                margin=args.miner_margin, type_of_triplets="all"
            )
        else:
            self.miner = None
        if self.loss == "ms_loss":
            self.loss = losses.MultiSimilarityLoss(
                alpha=2, beta=50, base=0.5
            )  # 1,2,3; 40,50,60
            # 1_40=1.5141 50=1.4988 60=1.4905 2_60=1.1786 50=1.1874 40=1.2008 3_40=1.1146 50=1.1012
        elif self.loss == "circle_loss":
            self.loss = losses.CircleLoss(
                m=0.4, gamma=80
            )
        elif self.loss == "triplet_loss":
            self.loss = losses.TripletMarginLoss(
                margin=0.05, swap=False, smooth_loss=False,
                triplets_per_anchor="all")
            # distance = CosineSimilarity(),
            # reducer = ThresholdReducer(high=0.3),
            # embedding_regularizer = LpRegularizer()
        elif self.loss == "infoNCE":
            self.loss = losses.NTXentLoss(
                temperature=0.07
            )  # The MoCo paper uses 0.07, while SimCLR uses 0.5.
        elif self.loss == "lifted_structure_loss":
            self.loss = losses.LiftedStructureLoss(
                neg_margin=1, pos_margin=0
            )
        elif self.loss == "nca_loss":
            self.loss = losses.NCALoss(
                softmax_scale=1
            )
        self.fusion = False
        # self.stack = False
        self.dropout = torch.nn.Dropout(args.dropout)
        print("miner:", self.miner)
        print("loss:", self.loss)
    # def add_fusion(self):
    #     adapter_setup = Fuse("prot_adapter", "disease_adapter")
    #     self.prot_encoder.add_fusion(adapter_setup)
    #     self.prot_encoder.set_active_adapters(adapter_setup)
    #     self.prot_encoder.train_fusion(adapter_setup)
    #     self.disease_encoder.add_fusion(adapter_setup)
    #     self.disease_encoder.set_active_adapters(adapter_setup)
    #     self.disease_encoder.train_fusion(adapter_setup)
    #     self.fusion = True

    # def add_stack_gda(self, reduction_factor):
    #     self.add_gda_adapters(reduction_factor=reduction_factor)
    #     # adapter_setup = Fuse("prot_adapter", "disease_adapter")
    #     self.prot_encoder.active_adapters = Stack(
    #         self.prot_adapter_name, self.gda_adapter_name
    #     )
    #     self.disease_encoder.active_adapters = Stack(
    #         self.disease_adapter_name, self.gda_adapter_name
    #     )
    #     print("stacked adapters loaded.")
    #     self.stack = True

    # def load_adapters(
    #     self,
    #     prot_model_path,
    #     disease_model_path,
    #     prot_adapter_name="prot_adapter",
    #     disease_adapter_name="disease_adapter",
    # ):
    #     if os.path.exists(prot_model_path):
    #         print(f"loading prot adapter from: {prot_model_path}")
    #         self.prot_adapter_name = prot_adapter_name
    #         self.prot_encoder.load_adapter(prot_model_path, load_as=prot_adapter_name)
    #         self.prot_encoder.set_active_adapters(prot_adapter_name)
    #         print(f"load protein adapters from: {prot_model_path} {prot_adapter_name}")
    #     else:
    #         print(f"{prot_model_path} not exits")
    #     if os.path.exists(disease_model_path):
    #         print(f"loading prot adapter from: {disease_model_path}")
    #         self.disease_adapter_name = disease_adapter_name
    #         self.disease_encoder.load_adapter(
    #             disease_model_path, load_as=disease_adapter_name
    #         )
    #         self.disease_encoder.set_active_adapters(disease_adapter_name)
    #         print(
    #             f"load disease adapters from: {disease_model_path} {disease_adapter_name}"
    #         )
    #     else:
    #         print(f"{disease_model_path} not exits")

    def non_adapters(
        self,
        prot_model_path,
        disease_model_path,
    ):
        if os.path.exists(prot_model_path):
            # Load the entire saved model and reuse its protein encoder.
            prot_model = torch.load(prot_model_path)
            self.prot_encoder = prot_model.prot_encoder
            print(f"load protein from: {prot_model_path}")
        else:
            print(f"{prot_model_path} does not exist")
        if os.path.exists(disease_model_path):
            # Load the entire saved model and reuse its disease encoder.
            disease_model = torch.load(disease_model_path)
            self.disease_encoder = disease_model.disease_encoder
            print(f"load disease from: {disease_model_path}")
        else:
            print(f"{disease_model_path} does not exist")

    # def add_gda_adapters(
    #     self,
    #     gda_adapter_name="gda_adapter",
    #     reduction_factor=16,
    # ):
    #     """Initialise adapters
    #     Args:
    #         prot_adapter_name (str, optional): _description_. Defaults to "prot_adapter".
    #         disease_adapter_name (str, optional): _description_. Defaults to "disease_adapter".
    #         reduction_factor (int, optional): _description_. Defaults to 16.
    #     """
    #     adapter_config = AdapterConfig.load(
    #         "pfeiffer", reduction_factor=reduction_factor
    #     )
    #     self.gda_adapter_name = gda_adapter_name
    #     self.prot_encoder.add_adapter(gda_adapter_name, config=adapter_config)
    #     self.prot_encoder.train_adapter([gda_adapter_name])
    #     self.disease_encoder.add_adapter(gda_adapter_name, config=adapter_config)
    #     self.disease_encoder.train_adapter([gda_adapter_name])

    # def init_adapters(
    #     self,
    #     prot_adapter_name="gda_prot_adapter",
    #     disease_adapter_name="gda_disease_adapter",
    #     reduction_factor=16,
    # ):
    #     """Initialise adapters
    #     Args:
    #         prot_adapter_name (str, optional): _description_. Defaults to "prot_adapter".
    #         disease_adapter_name (str, optional): _description_. Defaults to "disease_adapter".
    #         reduction_factor (int, optional): _description_. Defaults to 16.
    #     """
    #     adapter_config = AdapterConfig.load(
    #         "pfeiffer", reduction_factor=reduction_factor
    #     )
    #     self.prot_adapter_name = prot_adapter_name
    #     self.disease_adapter_name = disease_adapter_name
    #     self.prot_encoder.add_adapter(prot_adapter_name, config=adapter_config)
    #     self.prot_encoder.train_adapter([prot_adapter_name])
    #     self.disease_encoder.add_adapter(disease_adapter_name, config=adapter_config)
    #     self.disease_encoder.train_adapter([disease_adapter_name])
    #     print(f"adapter modules initialized")

    # def save_adapters(self, save_path_prefix, total_step):
    #     """Save adapters into file.
    #     Args:
    #         save_path_prefix (string): saving path prefix.
    #         total_step (int): total step number.
    #     """
    #     prot_save_dir = os.path.join(
    #         save_path_prefix, f"prot_adapter_step_{total_step}"
    #     )  # adapter
    #     disease_save_dir = os.path.join(
    #         save_path_prefix, f"disease_adapter_step_{total_step}"
    #     )
    #     os.makedirs(prot_save_dir, exist_ok=True)
    #     os.makedirs(disease_save_dir, exist_ok=True)
    #     self.prot_encoder.save_adapter(prot_save_dir, self.prot_adapter_name)
    #     prot_head_save_path = os.path.join(prot_save_dir, "prot_head.bin")
    #     torch.save(self.prot_reg, prot_head_save_path)
    #     self.disease_encoder.save_adapter(disease_save_dir, self.disease_adapter_name)
    #     disease_head_save_path = os.path.join(prot_save_dir, "disease_head.bin")
    #     torch.save(self.prot_reg, disease_head_save_path)
    #     if self.fusion:
    #         self.prot_encoder.save_all_adapters(prot_save_dir)
    #         self.disease_encoder.save_all_adapters(disease_save_dir)

    def predict(self, query_toks1, query_toks2):
        """Encode a protein/disease pair and return the concatenated pair embedding.

        query_toks1 / query_toks2: tokenizer outputs with input_ids and attention_mask.
        output: (N, 2 * hidden)
        """
        # Extract input_ids and attention_mask for the protein.
        prot_input_ids = query_toks1["input_ids"]
        prot_attention_mask = query_toks1["attention_mask"]
        # Extract input_ids and attention_mask for the disease.
        dis_input_ids = query_toks2["input_ids"]
        dis_attention_mask = query_toks2["attention_mask"]
        # Run both encoders.
        last_hidden_state1 = self.prot_encoder(
            input_ids=prot_input_ids, attention_mask=prot_attention_mask, return_dict=True
        ).last_hidden_state
        last_hidden_state1 = self.prot_reg(last_hidden_state1)
        last_hidden_state2 = self.disease_encoder(
            input_ids=dis_input_ids, attention_mask=dis_attention_mask, return_dict=True
        ).last_hidden_state
        last_hidden_state2 = self.dis_reg(last_hidden_state2)
        # Apply the cross-attention layer.
        prot_fused, dis_fused = self.cross_attention_layer(
            last_hidden_state1, last_hidden_state2, prot_attention_mask, dis_attention_mask
        )
        # last_hidden_state1 = self.prot_encoder(
        #     query_toks1, return_dict=True
        # ).last_hidden_state
        # last_hidden_state1 = self.prot_reg(
        #     last_hidden_state1
        # )  # transform the prot embedding into the same dimension as the disease embedding
        # last_hidden_state2 = self.disease_encoder(
        #     query_toks2, return_dict=True
        # ).last_hidden_state
        # last_hidden_state2 = self.dis_reg(
        #     last_hidden_state2
        # )  # transform the disease embedding into 1024
        # Apply the fusion layer and recover the representation shape:
        # prot_fused, dis_fused = self.fusion_layer(last_hidden_state1, last_hidden_state2)
        if self.agg_mode == "cls":
            query_embed1 = prot_fused[:, 0]  # query : [batch_size, hidden]
            query_embed2 = dis_fused[:, 0]  # query : [batch_size, hidden]
        elif self.agg_mode == "mean_all_tok":
            query_embed1 = prot_fused.mean(1)  # query : [batch_size, hidden]
            query_embed2 = dis_fused.mean(1)  # query : [batch_size, hidden]
        elif self.agg_mode == "mean":
            query_embed1 = (
                prot_fused * query_toks1["attention_mask"].unsqueeze(-1)
            ).sum(1) / query_toks1["attention_mask"].sum(-1).unsqueeze(-1)
            query_embed2 = (
                dis_fused * query_toks2["attention_mask"].unsqueeze(-1)
            ).sum(1) / query_toks2["attention_mask"].sum(-1).unsqueeze(-1)
        else:
            raise NotImplementedError()
        query_embed = torch.cat([query_embed1, query_embed2], dim=1)
        return query_embed

    def forward(self, query_toks1, query_toks2, labels):
        """Compute the metric-learning loss for a batch of protein/disease pairs.

        query_toks1 / query_toks2: tokenizer outputs with input_ids and attention_mask.
        labels: pair labels; only their count is used to build in-batch pseudo labels.
        """
        # Extract input_ids and attention_mask for the protein.
        prot_input_ids = query_toks1["input_ids"]
        prot_attention_mask = query_toks1["attention_mask"]
        # Extract input_ids and attention_mask for the disease.
        dis_input_ids = query_toks2["input_ids"]
        dis_attention_mask = query_toks2["attention_mask"]
        # Run both encoders.
        last_hidden_state1 = self.prot_encoder(
            input_ids=prot_input_ids, attention_mask=prot_attention_mask, return_dict=True
        ).last_hidden_state
        last_hidden_state1 = self.prot_reg(last_hidden_state1)
        last_hidden_state2 = self.disease_encoder(
            input_ids=dis_input_ids, attention_mask=dis_attention_mask, return_dict=True
        ).last_hidden_state
        last_hidden_state2 = self.dis_reg(last_hidden_state2)
        # Apply the cross-attention layer.
        prot_fused, dis_fused = self.cross_attention_layer(
            last_hidden_state1, last_hidden_state2, prot_attention_mask, dis_attention_mask
        )
        # last_hidden_state1 = self.prot_encoder(
        #     query_toks1, return_dict=True
        # ).last_hidden_state
        # last_hidden_state1 = self.prot_reg(
        #     last_hidden_state1
        # )  # transform the prot embedding into the same dimension as the disease embedding
        # last_hidden_state2 = self.disease_encoder(
        #     query_toks2, return_dict=True
        # ).last_hidden_state
        # last_hidden_state2 = self.dis_reg(
        #     last_hidden_state2
        # )  # transform the disease embedding into 1024
        # # Apply the fusion layer and recover the representation shape:
        # prot_fused, dis_fused = self.fusion_layer(last_hidden_state1, last_hidden_state2)
        if self.agg_mode == "cls":
            query_embed1 = prot_fused[:, 0]  # query : [batch_size, hidden]
            query_embed2 = dis_fused[:, 0]  # query : [batch_size, hidden]
        elif self.agg_mode == "mean_all_tok":
            query_embed1 = prot_fused.mean(1)  # query : [batch_size, hidden]
            query_embed2 = dis_fused.mean(1)  # query : [batch_size, hidden]
        elif self.agg_mode == "mean":
            query_embed1 = (
                prot_fused * query_toks1["attention_mask"].unsqueeze(-1)
            ).sum(1) / query_toks1["attention_mask"].sum(-1).unsqueeze(-1)
            query_embed2 = (
                dis_fused * query_toks2["attention_mask"].unsqueeze(-1)
            ).sum(1) / query_toks2["attention_mask"].sum(-1).unsqueeze(-1)
        else:
            raise NotImplementedError()
        # print("query_embed1 =", query_embed1.shape, "query_embed2 =", query_embed2.shape)
        query_embed = torch.cat([query_embed1, query_embed2], dim=0)
        # print("query_embed =", len(query_embed))
        # Each pair in the batch gets its own pseudo label, so the i-th protein and the i-th
        # disease embedding form a positive pair and all other embeddings act as negatives.
        labels = torch.cat([torch.arange(len(labels)), torch.arange(len(labels))], dim=0)
        if self.use_miner:
            hard_pairs = self.miner(query_embed, labels)
            return self.loss(query_embed, labels, hard_pairs)  # + loss_mmp
        else:
            loss = self.loss(query_embed, labels)  # + loss_mmp
            # print('loss :', loss)
            return loss
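
    # Minimal training-step sketch (illustrative only; the tokenizers, batch source and
    # optimizer are assumptions, not part of this module):
    #
    #   model = GDA_Metric_Learning(prot_encoder, disease_encoder, prot_out_dim=1280,
    #                               disease_out_dim=768, args=args)
    #   prot_toks = prot_tokenizer(prot_seqs, padding=True, truncation=True, return_tensors="pt")
    #   dis_toks = dis_tokenizer(dis_texts, padding=True, truncation=True, return_tensors="pt")
    #   loss = model(prot_toks, dis_toks, labels)   # labels only fix the batch size here
    #   loss.backward()
    #   optimizer.step()
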
    def get_embeddings(self, mentions, batch_size=1024):
        """Compute all embeddings from mention tokens.

        Note: this relies on a `self.vectorizer` callable being attached to the model
        elsewhere; it is not defined in this class.
        """
        embedding_table = []
        with torch.no_grad():
            for start in tqdm(range(0, len(mentions), batch_size)):
                end = min(start + batch_size, len(mentions))
                batch = mentions[start:end]
                batch_embedding = self.vectorizer(batch)
                batch_embedding = batch_embedding.cpu()
                embedding_table.append(batch_embedding)
        embedding_table = torch.cat(embedding_table, dim=0)
        return embedding_table


class DDA_Metric_Learning(Module):
    def __init__(self, disease_encoder, args):
        """Constructor for the model.
        Args:
            disease_encoder (_type_): disease encoder.
            args (_type_): _description_
        """
        super(DDA_Metric_Learning, self).__init__()
        self.disease_encoder = disease_encoder
        self.loss = args.loss
        self.use_miner = args.use_miner
        self.miner_margin = args.miner_margin
        self.agg_mode = args.agg_mode
        self.disease_adapter_name = None
        if self.use_miner:
            self.miner = miners.TripletMarginMiner(
                margin=args.miner_margin, type_of_triplets="all"
            )
        else:
            self.miner = None
        if self.loss == "ms_loss":
            self.loss = losses.MultiSimilarityLoss(
                alpha=1, beta=60, base=0.5
            )  # 1,2,3; 40,50,60
        elif self.loss == "circle_loss":
            self.loss = losses.CircleLoss()
        elif self.loss == "triplet_loss":
            self.loss = losses.TripletMarginLoss()
        elif self.loss == "infoNCE":
            self.loss = losses.NTXentLoss(
                temperature=0.07
            )  # The MoCo paper uses 0.07, while SimCLR uses 0.5.
        elif self.loss == "lifted_structure_loss":
            self.loss = losses.LiftedStructureLoss()
        elif self.loss == "nca_loss":
            self.loss = losses.NCALoss()
        self.reg = None
        self.cls = None
        self.dropout = torch.nn.Dropout(args.dropout)
        print("miner:", self.miner)
        print("loss:", self.loss)

    def add_classification_head(self, disease_out_dim=768, out_dim=2):
        """Add a classification head.
        Args:
            disease_out_dim (int, optional): disease encoder output dimension. Defaults to 768.
            out_dim (int, optional): output dimension. Defaults to 2.
        """
        self.cls = nn.Linear(disease_out_dim * 2, out_dim)

    def load_disease_adapter(
        self,
        disease_model_path,
        disease_adapter_name="disease_adapter",
    ):
        if os.path.exists(disease_model_path):
            self.disease_adapter_name = disease_adapter_name
            self.disease_encoder.load_adapter(
                disease_model_path, load_as=disease_adapter_name
            )
            self.disease_encoder.set_active_adapters(disease_adapter_name)
            print(
                f"load disease adapters from: {disease_model_path} {disease_adapter_name}"
            )
        else:
            print(f"{disease_model_path} does not exist")

    def init_adapters(
        self,
        disease_adapter_name="disease_adapter",
        reduction_factor=16,
    ):
        """Initialise adapters
        Args:
            disease_adapter_name (str, optional): _description_. Defaults to "disease_adapter".
            reduction_factor (int, optional): _description_. Defaults to 16.
        """
        adapter_config = AdapterConfig.load(
            "pfeiffer", reduction_factor=reduction_factor
        )
        self.disease_adapter_name = disease_adapter_name
        self.disease_encoder.add_adapter(disease_adapter_name, config=adapter_config)
        self.disease_encoder.train_adapter([disease_adapter_name])

    def save_adapters(self, save_path_prefix, total_step):
        """Save adapters into file.
        Args:
            save_path_prefix (string): saving path prefix.
            total_step (int): total step number.
        """
        disease_save_dir = os.path.join(
            save_path_prefix, f"disease_adapter_step_{total_step}"
        )
        os.makedirs(disease_save_dir, exist_ok=True)
        self.disease_encoder.save_adapter(disease_save_dir, self.disease_adapter_name)

    def predict(self, x1, x2):
        """Encode a disease/disease pair and return the concatenated embedding (N, 2 * hidden)."""
        if self.agg_mode == "cls":
            x1 = self.disease_encoder(x1).last_hidden_state[:, 0]
            x2 = self.disease_encoder(x2).last_hidden_state[:, 0]
            x = torch.cat((x1, x2), 1)
            return x
        else:
            x1 = self.disease_encoder(x1).last_hidden_state.mean(1)  # query : [batch_size, hidden]
            x2 = self.disease_encoder(x2).last_hidden_state.mean(1)  # query : [batch_size, hidden]
            x = torch.cat((x1, x2), 1)
            return x

    def module_predict(self, x1, x2):
        """Same as predict(), but for an encoder wrapped so that it is reached through
        `.module` (typically DataParallel/DistributedDataParallel)."""
        if self.agg_mode == "cls":
            x1 = self.disease_encoder.module(x1).last_hidden_state[:, 0]
            x2 = self.disease_encoder.module(x2).last_hidden_state[:, 0]
            x = torch.cat((x1, x2), 1)
            return x
        else:
            x1 = self.disease_encoder.module(x1).last_hidden_state.mean(1)  # query : [batch_size, hidden]
            x2 = self.disease_encoder.module(x2).last_hidden_state.mean(1)  # query : [batch_size, hidden]
            x = torch.cat((x1, x2), 1)
            return x

    def forward(self, query_toks1, query_toks2, labels):
        """Compute the metric-learning loss for a batch of disease/disease pairs."""
        last_hidden_state1 = self.disease_encoder(
            **query_toks1, return_dict=True
        ).last_hidden_state
        last_hidden_state2 = self.disease_encoder(
            **query_toks2, return_dict=True
        ).last_hidden_state
        if self.agg_mode == "cls":
            query_embed1 = last_hidden_state1[:, 0]  # query : [batch_size, hidden]
            query_embed2 = last_hidden_state2[:, 0]  # query : [batch_size, hidden]
        elif self.agg_mode == "mean_all_tok":
            query_embed1 = last_hidden_state1.mean(1)  # query : [batch_size, hidden]
            query_embed2 = last_hidden_state2.mean(1)  # query : [batch_size, hidden]
        elif self.agg_mode == "mean":
            query_embed1 = (
                last_hidden_state1 * query_toks1["attention_mask"].unsqueeze(-1)
            ).sum(1) / query_toks1["attention_mask"].sum(-1).unsqueeze(-1)
            query_embed2 = (
                last_hidden_state2 * query_toks2["attention_mask"].unsqueeze(-1)
            ).sum(1) / query_toks2["attention_mask"].sum(-1).unsqueeze(-1)
        else:
            raise NotImplementedError()
        query_embed = torch.cat([query_embed1, query_embed2], dim=0)
        labels = torch.cat([labels, labels], dim=0)
        if self.use_miner:
            hard_pairs = self.miner(query_embed, labels)
            print('miner used')
            return self.loss(query_embed, labels, hard_pairs)
        else:
            print('no miner')
            return self.loss(query_embed, labels)


class PPI_Metric_Learning(Module):
    def __init__(self, prot_encoder, args):
        """Constructor for the model.
        Args:
            prot_encoder (_type_): protein encoder.
            args (_type_): _description_
        """
        super(PPI_Metric_Learning, self).__init__()
        self.prot_encoder = prot_encoder
        self.loss = args.loss
        self.use_miner = args.use_miner
        self.miner_margin = args.miner_margin
        self.agg_mode = args.agg_mode
        self.prot_adapter_name = None
        if self.use_miner:
            self.miner = miners.TripletMarginMiner(
                margin=args.miner_margin, type_of_triplets="all"
            )
        else:
            self.miner = None
        if self.loss == "ms_loss":
            self.loss = losses.MultiSimilarityLoss(
                alpha=1, beta=60, base=0.5
            )  # 1,2,3; 40,50,60
        elif self.loss == "circle_loss":
            self.loss = losses.CircleLoss()
        elif self.loss == "triplet_loss":
            self.loss = losses.TripletMarginLoss()
        elif self.loss == "infoNCE":
            self.loss = losses.NTXentLoss(
                temperature=0.07
            )  # The MoCo paper uses 0.07, while SimCLR uses 0.5.
        elif self.loss == "lifted_structure_loss":
            self.loss = losses.LiftedStructureLoss()
        elif self.loss == "nca_loss":
            self.loss = losses.NCALoss()
        self.reg = None
        self.cls = None
        self.dropout = torch.nn.Dropout(args.dropout)
        print("miner:", self.miner)
        print("loss:", self.loss)

    def add_classification_head(self, prot_out_dim=1024, out_dim=2):
        """Add a classification head.
        Args:
            prot_out_dim (int, optional): protein encoder output dimension. Defaults to 1024.
            out_dim (int, optional): output dimension. Defaults to 2.
        """
        self.cls = nn.Linear(prot_out_dim + prot_out_dim, out_dim)

    def load_prot_adapter(
        self,
        prot_model_path,
        prot_adapter_name="prot_adapter",
    ):
        if os.path.exists(prot_model_path):
            self.prot_adapter_name = prot_adapter_name
            self.prot_encoder.load_adapter(prot_model_path, load_as=prot_adapter_name)
            self.prot_encoder.set_active_adapters(prot_adapter_name)
            print(f"load protein adapters from: {prot_model_path} {prot_adapter_name}")
        else:
            print(f"{prot_model_path} does not exist")

    def init_adapters(
        self,
        prot_adapter_name="prot_adapter",
        reduction_factor=16,
    ):
        """Initialise adapters
        Args:
            prot_adapter_name (str, optional): _description_. Defaults to "prot_adapter".
            reduction_factor (int, optional): _description_. Defaults to 16.
        """
        adapter_config = AdapterConfig.load(
            "pfeiffer", reduction_factor=reduction_factor
        )
        self.prot_adapter_name = prot_adapter_name
        self.prot_encoder.add_adapter(prot_adapter_name, config=adapter_config)
        self.prot_encoder.train_adapter([prot_adapter_name])

    def save_adapters(self, save_path_prefix, total_step):
        """Save adapters into file.
        Args:
            save_path_prefix (string): saving path prefix.
            total_step (int): total step number.
        """
        prot_save_dir = os.path.join(
            save_path_prefix, f"prot_adapter_step_{total_step}"
        )
        os.makedirs(prot_save_dir, exist_ok=True)
        self.prot_encoder.save_adapter(prot_save_dir, self.prot_adapter_name)

    def predict(self, x1, x2):
        """Encode a protein/protein pair and return the concatenated embedding (N, 2 * hidden)."""
        if self.agg_mode == "cls":
            x1 = self.prot_encoder(x1).last_hidden_state[:, 0]
            x2 = self.prot_encoder(x2).last_hidden_state[:, 0]
            x = torch.cat((x1, x2), 1)
            return x
        else:
            x1 = self.prot_encoder(x1).last_hidden_state.mean(1)  # query : [batch_size, hidden]
            x2 = self.prot_encoder(x2).last_hidden_state.mean(1)  # query : [batch_size, hidden]
            x = torch.cat((x1, x2), 1)
            return x

    def module_predict(self, x1, x2):
        """Same as predict(), but for an encoder wrapped so that it is reached through
        `.module` (typically DataParallel/DistributedDataParallel)."""
        if self.agg_mode == "cls":
            x1 = self.prot_encoder.module(x1).last_hidden_state[:, 0]
            x2 = self.prot_encoder.module(x2).last_hidden_state[:, 0]
            x = torch.cat((x1, x2), 1)
            return x
        else:
            x1 = self.prot_encoder.module(x1).last_hidden_state.mean(1)  # query : [batch_size, hidden]
            x2 = self.prot_encoder.module(x2).last_hidden_state.mean(1)  # query : [batch_size, hidden]
            x = torch.cat((x1, x2), 1)
            return x

    def forward(self, query_toks1, query_toks2, labels):
        """Compute the metric-learning loss for a batch of protein/protein pairs."""
        last_hidden_state1 = self.prot_encoder(
            **query_toks1, return_dict=True
        ).last_hidden_state
        last_hidden_state2 = self.prot_encoder(
            **query_toks2, return_dict=True
        ).last_hidden_state
        if self.agg_mode == "cls":
            query_embed1 = last_hidden_state1[:, 0]  # query : [batch_size, hidden]
            query_embed2 = last_hidden_state2[:, 0]  # query : [batch_size, hidden]
        elif self.agg_mode == "mean_all_tok":
            query_embed1 = last_hidden_state1.mean(1)  # query : [batch_size, hidden]
            query_embed2 = last_hidden_state2.mean(1)  # query : [batch_size, hidden]
        elif self.agg_mode == "mean":
            query_embed1 = (
                last_hidden_state1 * query_toks1["attention_mask"].unsqueeze(-1)
            ).sum(1) / query_toks1["attention_mask"].sum(-1).unsqueeze(-1)
            query_embed2 = (
                last_hidden_state2 * query_toks2["attention_mask"].unsqueeze(-1)
            ).sum(1) / query_toks2["attention_mask"].sum(-1).unsqueeze(-1)
        else:
            raise NotImplementedError()
        query_embed = torch.cat([query_embed1, query_embed2], dim=0)
        labels = torch.cat([labels, labels], dim=0)
        if self.use_miner:
            hard_pairs = self.miner(query_embed, labels)
            return self.loss(query_embed, labels, hard_pairs)
        else:
            return self.loss(query_embed, labels)