---
license: mit
inference: false
datasets:
- ibm-research/otter_uniprot_bindingdb
---

# Otter UB UM Model Card

Otter-Knowledge model trained using only one modality for molecules: UniMol (UM).

## Model details

Otter models are based on Graph Neural Networks (GNNs), which propagate initial embeddings through a stack of layers that update each node's embedding according to its neighbours.

The GNN architecture consists of two main blocks, an encoder and a decoder:

- The encoder starts with a projection layer, a set of linear transformations (one per node modality) that projects all nodes into a common dimensionality. It then applies several multi-relational graph convolutional layers (R-GCN), which distinguish between the types of edges connecting source and target nodes by keeping a separate set of trainable parameters for each edge type.

- The decoder addresses the link prediction task. It is a scoring function that maps each triple (source node, edge, target node) to a scalar in the interval [0, 1].
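
The encoder and decoder described above can be illustrated with a small, self-contained PyTorch sketch. This is not the Otter implementation: the class names (`TinyRGCNLayer`, `TinyEncoder`), the modality names and dimensions, the sum aggregation without normalisation, and the DistMult-style scoring function followed by a sigmoid are all assumptions chosen to keep the example short.

```python
import torch
from torch import nn


class TinyRGCNLayer(nn.Module):
    """One relational graph-convolution step with a weight matrix per edge type."""

    def __init__(self, dim, num_relations):
        super().__init__()
        self.rel_weights = nn.ModuleList(
            [nn.Linear(dim, dim, bias=False) for _ in range(num_relations)]
        )
        self.self_loop = nn.Linear(dim, dim)

    def forward(self, x, src, rel, dst):
        out = self.self_loop(x)
        for r, weight in enumerate(self.rel_weights):
            mask = rel == r
            if mask.any():
                # Sum messages sent from source to target nodes over edges of type r.
                out = out.index_add(0, dst[mask], weight(x[src[mask]]))
        return torch.relu(out)


class TinyEncoder(nn.Module):
    """Per-modality projection into a shared space, followed by R-GCN layers."""

    def __init__(self, modality_dims, dim, num_relations, num_layers=2):
        super().__init__()
        self.dim = dim
        self.project = nn.ModuleDict(
            {name: nn.Linear(d, dim) for name, d in modality_dims.items()}
        )
        self.layers = nn.ModuleList(
            [TinyRGCNLayer(dim, num_relations) for _ in range(num_layers)]
        )

    def forward(self, features, node_ids, num_nodes, src, rel, dst):
        x = torch.zeros(num_nodes, self.dim)
        for name, feats in features.items():
            # Each modality has its own linear map into the common dimensionality.
            x[node_ids[name]] = self.project[name](feats)
        for layer in self.layers:
            x = layer(x, src, rel, dst)
        return x


def distmult_score(h_src, rel_emb, h_dst):
    """Score a (source, edge, target) triple and squash it into [0, 1]."""
    return torch.sigmoid((h_src * rel_emb * h_dst).sum(dim=-1))


# Toy graph with two nodes (hypothetical modalities and sizes) and two typed edges.
torch.manual_seed(0)
encoder = TinyEncoder({"drug": 8, "protein": 12}, dim=4, num_relations=2)
features = {"drug": torch.randn(1, 8), "protein": torch.randn(1, 12)}
node_ids = {"drug": torch.tensor([0]), "protein": torch.tensor([1])}
src, rel, dst = torch.tensor([0, 1]), torch.tensor([0, 1]), torch.tensor([1, 0])
node_embeddings = encoder(features, node_ids, 2, src, rel, dst)
score = distmult_score(node_embeddings[0], torch.randn(4), node_embeddings[1])
```

Here `node_embeddings` has one row per node in the shared space, and `score` is a single value between 0 and 1 for the drug-protein pair.
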

**Model training data:**

The model was trained on *Uniprot-BindingDB*.

**Paper or resources for more information:**

- [GitHub Repo](https://github.com/IBM/otter-knowledge)
- [Paper](https://arxiv.org/abs/2306.12802)

**License:**

MIT

**Where to send questions or comments about the model:**

- [GitHub Repo](https://github.com/IBM/otter-knowledge)

## How to use

Clone the repo:

```sh
git clone https://github.com/IBM/otter-knowledge.git
cd otter-knowledge
```

- Use the `BindingAffinity` class:

```python
import torch
from torch import nn


class BindingAffinity(nn.Module):

    def __init__(self, gnn, drug_modality):
        super(BindingAffinity, self).__init__()
        self.drug_modality = drug_modality
        self.protein_modality = 'protein-sequence-mean'
        self.drug_entity_name = 'Drug'
        self.protein_entity_name = 'Protein'
        self.drug_rel_id = 1
        self.protein_rel_id = 2
        self.protein_drug_rel_id = 0
        self.gnn = gnn
        self.device = 'cpu'
        hd1 = 512
        num_input = 2
        # Regression head applied to the concatenated drug and protein embeddings
        self.combine = torch.nn.ModuleList([nn.Linear(num_input * hd1, hd1), nn.ReLU(),
                                            nn.Linear(hd1, hd1), nn.ReLU(),
                                            nn.Linear(hd1, 1)])
        self.to(self.device)

    def forward(self, drug_embedding, protein_embedding):
        # Four-node graph: entity nodes (0: Drug, 2: Protein) and their modality nodes (1, 3)
        nodes = {
            self.drug_modality: {
                'embeddings': drug_embedding.unsqueeze(0).to(self.device),
                'node_indices': torch.tensor([1]).to(self.device)
            },
            self.drug_entity_name: {
                'embeddings': [None],
                'node_indices': torch.tensor([0]).to(self.device)
            },
            self.protein_modality: {
                'embeddings': protein_embedding.unsqueeze(0).to(self.device),
                'node_indices': torch.tensor([3]).to(self.device)
            },
            self.protein_entity_name: {
                'embeddings': [None],
                'node_indices': torch.tensor([2]).to(self.device)
            }
        }
        triples = torch.tensor([[1, 3],
                                [3, 4],
                                [0, 2]]).to(self.device)
        gnn_embeddings = self.gnn.encoder(nodes, triples)
        node_gnn_embeddings = []
        all_indices = [0, 2]

        # Select the GNN embeddings of the Drug (0) and Protein (2) entity nodes
        for indices in all_indices:
            node_gnn_embedding = torch.index_select(gnn_embeddings, dim=0, index=torch.tensor(indices).to(self.device))
            node_gnn_embeddings.append(node_gnn_embedding)

        # Concatenate both embeddings and pass them through the regression head
        c = torch.cat(node_gnn_embeddings, dim=-1)
        for m in self.combine:
            c = m(c)

        return c
```

- Run inference with the initial embeddings, i.e. the embeddings produced by the handlers (UniMol for the SMILES string, ESM1b for the protein sequence):

```python
# `net` is a BindingAffinity instance; the two embeddings come from the handlers
p = net(drug_embedding=drug_embedding, protein_embedding=protein_embedding)
print(p)
```
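
To make the shape of `p` easier to interpret, the sketch below reproduces only the prediction head of `BindingAffinity` (the `combine` layers) and feeds it random tensors. The random inputs are stand-ins for the GNN embeddings of the Drug and Protein nodes and are an assumption of this example; the sketch does not load the Otter GNN or any trained weights, so the resulting number is meaningless.

```python
import torch
from torch import nn

hd1 = 512        # hidden size used by BindingAffinity
num_input = 2    # one drug embedding plus one protein embedding

# Same layer stack as `self.combine` in BindingAffinity (untrained here).
combine = nn.ModuleList([nn.Linear(num_input * hd1, hd1), nn.ReLU(),
                         nn.Linear(hd1, hd1), nn.ReLU(),
                         nn.Linear(hd1, 1)])

# Random stand-ins for the GNN outputs of the Drug and Protein nodes (assumption).
drug_node = torch.randn(1, hd1)
protein_node = torch.randn(1, hd1)

c = torch.cat([drug_node, protein_node], dim=-1)  # shape (1, 1024)
with torch.no_grad():
    for layer in combine:
        c = layer(c)

affinity = c.item()  # the single predicted value as a Python float
```

Under these assumptions the head maps a `(1, 1024)` tensor to a `(1, 1)` tensor, so calling `.item()` on the output yields one scalar per drug-protein pair.
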