Otter-Knowledge model trained using only one modality for molecules: SMI-TED (SMI)
Otter models are based on Graph Neural Networks (GNNs) that propagate initial embeddings through a set of layers, each of which updates a node's embedding according to its neighbours. The GNN architecture consists of two main blocks: an encoder and a decoder.
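The propagation idea described above can be illustrated with a small, generic message-passing layer. This is a minimal sketch and not the Otter implementation: the class name `MeanNeighbourLayer`, the mean aggregation, the toy graph, and the embedding size of 8 are all assumptions chosen for illustration.

```python
import torch
from torch import nn


class MeanNeighbourLayer(nn.Module):
    """One illustrative propagation layer (hypothetical, not the Otter code)."""

    def __init__(self, dim):
        super().__init__()
        self.self_proj = nn.Linear(dim, dim)   # transforms the node's own embedding
        self.neigh_proj = nn.Linear(dim, dim)  # transforms the aggregated neighbours

    def forward(self, x, edges):
        # x: (num_nodes, dim) embeddings; edges: (2, num_edges) as [sources; targets]
        src, dst = edges
        # Sum the embeddings of incoming neighbours for every target node.
        agg = torch.zeros_like(x).index_add_(0, dst, x[src])
        # Count incoming edges so the sum becomes a mean (clamp avoids division by zero).
        ones = torch.ones(dst.size(0), dtype=x.dtype)
        deg = torch.zeros(x.size(0), dtype=x.dtype).index_add_(0, dst, ones)
        agg = agg / deg.clamp(min=1).unsqueeze(-1)
        return torch.relu(self.self_proj(x) + self.neigh_proj(agg))


torch.manual_seed(0)
x = torch.randn(4, 8)                    # toy initial embeddings for 4 nodes
edges = torch.tensor([[1, 3], [0, 2]])   # toy graph: node 1 -> node 0, node 3 -> node 2
layers = nn.ModuleList([MeanNeighbourLayer(8), MeanNeighbourLayer(8)])

h = x
for layer in layers:                     # each layer updates embeddings from neighbours
    h = layer(h, edges)
print(h.shape)
```

Stacking more layers lets information travel further across the graph; with two layers, each node sees neighbours up to two hops away.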
Model training data:
The model was trained on Uniprot-BindingDB.
Paper or resources for more information:
License:
MIT
Where to send questions or comments about the model:
Clone the repo:
```sh
git clone https://github.com/IBM/otter-knowledge.git
cd otter-knowledge
```
```python
import torch
from torch import nn


class BindingAffinity(nn.Module):
    """Scores a drug-protein pair from the embeddings produced by the GNN encoder."""

    def __init__(self, gnn, drug_modality):
        super(BindingAffinity, self).__init__()
        self.drug_modality = drug_modality
        self.protein_modality = 'protein-sequence-mean'
        self.drug_entity_name = 'Drug'
        self.protein_entity_name = 'Protein'
        self.drug_rel_id = 1
        self.protein_rel_id = 2
        self.protein_drug_rel_id = 0
        self.gnn = gnn
        self.device = 'cpu'
        hd1 = 512      # hidden size of the prediction head
        num_input = 2  # drug and protein embeddings are concatenated
        self.combine = torch.nn.ModuleList([nn.Linear(num_input * hd1, hd1), nn.ReLU(),
                                            nn.Linear(hd1, hd1), nn.ReLU(),
                                            nn.Linear(hd1, 1)])
        self.to(self.device)

    def forward(self, drug_embedding, protein_embedding):
        # Modality nodes carry the initial embeddings; entity nodes start empty.
        nodes = {
            self.drug_modality: {
                'embeddings': drug_embedding.unsqueeze(0).to(self.device),
                'node_indices': torch.tensor([1]).to(self.device)
            },
            self.drug_entity_name: {
                'embeddings': [None],
                'node_indices': torch.tensor([0]).to(self.device)
            },
            self.protein_modality: {
                'embeddings': protein_embedding.unsqueeze(0).to(self.device),
                'node_indices': torch.tensor([3]).to(self.device)
            },
            self.protein_entity_name: {
                'embeddings': [None],
                'node_indices': torch.tensor([2]).to(self.device)
            }
        }
        # Graph connectivity passed to the GNN encoder.
        triples = torch.tensor([[1, 3],
                                [3, 4],
                                [0, 2]]).to(self.device)
        gnn_embeddings = self.gnn.encoder(nodes, triples)

        # Pick the embeddings of the Drug (index 0) and Protein (index 2) entity nodes.
        node_gnn_embeddings = []
        all_indices = [0, 2]
        for indices in all_indices:
            node_gnn_embedding = torch.index_select(gnn_embeddings, dim=0, index=torch.tensor(indices).to(self.device))
            node_gnn_embeddings.append(node_gnn_embedding)

        # Concatenate both embeddings and apply the prediction head layer by layer.
        c = torch.cat(node_gnn_embeddings, dim=-1)
        for m in self.combine:
            c = m(c)
        return c
```
- Run inference with the initial embeddings, i.e. the embeddings produced by the handlers (SMI-TED for the SMILES string, ESM1b for the protein sequence):
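```python
# `net` is assumed to be an instance of the BindingAffinity class defined above;
# `drug_embedding` and `protein_embedding` are the initial embeddings from the handlers.
p = net(drug_embedding=drug_embedding, protein_embedding=protein_embedding)
print(p)
```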
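To make the shapes in the last step of `BindingAffinity.forward` concrete, here is a minimal sketch of the prediction head in isolation. It assumes a hypothetical stand-in for the encoder output (a random 4×512 tensor, one row per node index) and randomly initialised weights, so the printed value is meaningless; only the tensor shapes are of interest.

```python
import torch
from torch import nn

hd1 = 512  # hidden size used by the `combine` layers in BindingAffinity

# Hypothetical stand-in for the GNN encoder output: one row per node index 0..3.
torch.manual_seed(0)
gnn_embeddings = torch.randn(4, hd1)

# Rows 0 and 2 correspond to the 'Drug' and 'Protein' entity nodes.
drug_entity = gnn_embeddings[0:1]        # shape (1, 512)
protein_entity = gnn_embeddings[2:3]     # shape (1, 512)
pair = torch.cat([drug_entity, protein_entity], dim=-1)  # shape (1, 1024)

# Same layer sizes as `self.combine`, written as nn.Sequential for brevity
# (assumption: weights here are random, not the trained ones).
head = nn.Sequential(
    nn.Linear(2 * hd1, hd1), nn.ReLU(),
    nn.Linear(hd1, hd1), nn.ReLU(),
    nn.Linear(hd1, 1),
)

with torch.no_grad():
    score = head(pair)
print(score.shape)
```

The head reduces the concatenated 1024-dimensional pair representation to a single scalar per drug-protein pair.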