Spaces:
Sleeping
Sleeping
| import sys | |
| import torch | |
| import torch.nn as nn | |
| sys.path.append("../") | |
class GDANet(torch.nn.Module):
    """Two-encoder model (protein + disease) with optional linear heads.

    Prediction heads are not created in ``__init__``. Attach them with
    ``add_regression_head`` and/or ``add_classification_head``.
    """

    def __init__(
        self,
        prot_encoder,
        disease_encoder,
    ):
        """Store the two encoders; both prediction heads start as ``None``.

        Args:
            prot_encoder: protein encoder. Within the methods visible here only
                its ``.parameters()`` is used (by ``freeze_encoders``);
                presumably a ``torch.nn.Module`` -- TODO confirm.
            disease_encoder: disease textual encoder; same assumptions as
                ``prot_encoder``.
        """
        super().__init__()
        self.prot_encoder = prot_encoder
        self.disease_encoder = disease_encoder
        # Heads are attached later via add_classification_head / add_regression_head.
        self.cls = None
        self.reg = None

    def add_regression_head(self, prot_out_dim=1024, disease_out_dim=768):
        """Attach a single-output linear regression head as ``self.reg``.

        The head maps a vector of size ``prot_out_dim + disease_out_dim``
        (presumably the concatenated encoder outputs -- TODO confirm) to one
        value. Calling this again replaces any previously attached head.

        Args:
            prot_out_dim (int, optional): protein encoder output dimension.
                Defaults to 1024.
            disease_out_dim (int, optional): disease encoder output dimension.
                Defaults to 768.
        """
        self.reg = nn.Linear(prot_out_dim + disease_out_dim, 1)

    def add_classification_head(
        self, prot_out_dim=1024, disease_out_dim=768, out_dim=2
    ):
        """Attach a linear classification head as ``self.cls``.

        The head maps a vector of size ``prot_out_dim + disease_out_dim``
        (presumably the concatenated encoder outputs -- TODO confirm) to
        ``out_dim`` values. Calling this again replaces any previous head.

        Args:
            prot_out_dim (int, optional): protein encoder output dimension.
                Defaults to 1024.
            disease_out_dim (int, optional): disease encoder output dimension.
                Defaults to 768.
            out_dim (int, optional): output dimension. Defaults to 2.
        """
        self.cls = nn.Linear(prot_out_dim + disease_out_dim, out_dim)

    def freeze_encoders(self, freeze_prot_encoder, freeze_disease_encoder):
        """Freeze or unfreeze each encoder's parameters independently.

        A truthy flag sets ``requires_grad = False`` on every parameter of the
        corresponding encoder; a falsy flag sets it to ``True``. Both flag
        values are then printed.

        Args:
            freeze_prot_encoder (bool): freeze the protein encoder.
            freeze_disease_encoder (bool): freeze the disease textual encoder.
        """
        # Each flag must control only its own encoder; previously the
        # "unfreeze" path of the protein flag touched disease_encoder, so the
        # protein encoder could never be made trainable again.
        self._set_trainable(self.prot_encoder, not freeze_prot_encoder)
        self._set_trainable(self.disease_encoder, not freeze_disease_encoder)
        print(f"freeze_prot_encoder:{freeze_prot_encoder}")
        print(f"freeze_disease_encoder:{freeze_disease_encoder}")

    @staticmethod
    def _set_trainable(module, trainable):
        """Set ``requires_grad = trainable`` on every parameter of *module*."""
        for param in module.parameters():
            param.requires_grad = trainable