---
license: mit
inference: false
datasets:
- ibm-research/otter_uniprot_bindingdb
---

# Otter UB MF Model Card

Otter-Knowledge model trained using only one modality for molecules: Morgan fingerprints (MF).

## Model details
Otter models are based on Graph Neural Networks (GNNs) that propagate initial embeddings through a set of layers, updating each node's embedding according to its neighbours.
The GNN architecture consists of two main blocks: an encoder and a decoder.
- The encoder starts with a projection layer, a separate linear transformation for each node modality, that maps all nodes into a common dimensionality. It then applies several multi-relational graph convolutional (R-GCN) layers, which distinguish between the edge types connecting source and target nodes by keeping a separate set of trainable parameters for each edge type.
- The decoder addresses a link prediction task: a scoring function maps each triple (source node, edge, target node) to a scalar in the interval [0, 1].
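The encoder and decoder described above can be illustrated with a minimal sketch in plain Python (standard library only, so it runs without PyTorch). This is not the Otter implementation. The toy graph, dimensions, modality names and random weights are hypothetical; the encoder uses the standard R-GCN update (a self-connection plus one weight matrix per relation, normalised by neighbour count); and the decoder uses a DistMult-style score passed through a sigmoid, which is one common choice, since this card does not name the scoring function Otter uses.

```python
from __future__ import annotations

import math
import random


def matvec(w: list[list[float]], x: list[float]) -> list[float]:
    """Multiply a (rows x cols) matrix by a vector of length cols."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def rand_matrix(rows: int, cols: int, rng: random.Random) -> list[list[float]]:
    """Small random weights standing in for trained parameters."""
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


def project(features, modality_of, proj):
    """Projection layer: one linear map per modality into a shared dimension."""
    return {n: matvec(proj[modality_of[n]], x) for n, x in features.items()}


def rgcn_layer(h, edges, w_self, w_rel):
    """One R-GCN layer:
    h_i' = ReLU(W_0 h_i + sum_r sum_{j in N_i^r} (1 / c_{i,r}) W_r h_j),
    where c_{i,r} is the number of relation-r neighbours of node i.
    """
    counts = {}
    for _, rel, dst in edges:
        counts[(dst, rel)] = counts.get((dst, rel), 0) + 1
    # Self-connection term W_0 h_i for every node
    out = {i: matvec(w_self, hi) for i, hi in h.items()}
    # Relation-specific messages: each edge type r has its own matrix W_r
    for src, rel, dst in edges:
        msg = matvec(w_rel[rel], h[src])
        c = counts[(dst, rel)]
        out[dst] = [a + m / c for a, m in zip(out[dst], msg)]
    return {i: [max(0.0, v) for v in vec] for i, vec in out.items()}


def distmult_score(h_src, rel_vec, h_dst):
    """DistMult-style triple score squashed into [0, 1] by a sigmoid."""
    s = sum(a * r * b for a, r, b in zip(h_src, rel_vec, h_dst))
    return 1.0 / (1.0 + math.exp(-s))


rng = random.Random(0)
DIM = 4  # hypothetical shared embedding size after projection

# Hypothetical toy graph: node 0 is a drug with a 2-d fingerprint-like input,
# node 1 is a protein with a 3-d sequence-embedding-like input.
features = {0: [1.0, 0.0], 1: [0.2, 0.5, 0.1]}
modality_of = {0: "morgan-fingerprint", 1: "protein-sequence"}
proj = {"morgan-fingerprint": rand_matrix(DIM, 2, rng),
        "protein-sequence": rand_matrix(DIM, 3, rng)}
# Edges as (source, relation id, target); relation 0 is a hypothetical "binds"
edges = [(0, 0, 1), (1, 0, 0)]

# Encoder: projection followed by two R-GCN layers, each with its own weights
h = project(features, modality_of, proj)
for _ in range(2):
    w_self = rand_matrix(DIM, DIM, rng)
    w_rel = {0: rand_matrix(DIM, DIM, rng)}
    h = rgcn_layer(h, edges, w_self, w_rel)

# Decoder: score the (drug, binds, protein) triple
rel_vec = [rng.uniform(-1.0, 1.0) for _ in range(DIM)]
score = distmult_score(h[0], rel_vec, h[1])
print(f"score for (drug, binds, protein): {score:.3f}")
```

Because the decoder in this sketch only consumes the final node embeddings, the scoring function can be swapped without touching the encoder.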


**Model training data:**

The model was trained on *Uniprot-BindingDB*.


**Paper or resources for more information:**
- [GitHub Repo](https://github.com/IBM/otter-knowledge)
- [Paper](https://arxiv.org/abs/2306.12802)

**License:**

MIT

**Where to send questions or comments about the model:**
- [GitHub Repo](https://github.com/IBM/otter-knowledge)

## How to use

Clone the repo:
```sh
git clone https://github.com/IBM/otter-knowledge.git
cd otter-knowledge
```

- Use the `BindingAffinity` class:

```python
import torch
from torch import nn


class BindingAffinity(nn.Module):
    """Binding-affinity head on top of a pretrained Otter GNN encoder."""

    def __init__(self, gnn, drug_modality):
        super().__init__()
        self.drug_modality = drug_modality
        self.protein_modality = 'protein-sequence-mean'
        self.drug_entity_name = 'Drug'
        self.protein_entity_name = 'Protein'
        self.drug_rel_id = 1
        self.protein_rel_id = 2
        self.protein_drug_rel_id = 0
        self.gnn = gnn
        self.device = 'cpu'
        hd1 = 512
        num_input = 2  # the drug and protein GNN embeddings are concatenated
        # MLP mapping the concatenated embeddings to a single affinity score
        self.combine = torch.nn.ModuleList([nn.Linear(num_input * hd1, hd1), nn.ReLU(),
                                            nn.Linear(hd1, hd1), nn.ReLU(),
                                            nn.Linear(hd1, 1)])
        self.to(self.device)

    def forward(self, drug_embedding, protein_embedding):
        # Four-node graph: Drug entity (0), drug modality (1),
        # Protein entity (2) and protein modality (3).
        # Entity nodes are passed without an initial embedding ([None]).
        nodes = {
            self.drug_modality: {
                'embeddings': drug_embedding.unsqueeze(0).to(self.device),
                'node_indices': torch.tensor([1]).to(self.device)
            },
            self.drug_entity_name: {
                'embeddings': [None],
                'node_indices': torch.tensor([0]).to(self.device)
            },
            self.protein_modality: {
                'embeddings': protein_embedding.unsqueeze(0).to(self.device),
                'node_indices': torch.tensor([3]).to(self.device)
            },
            self.protein_entity_name: {
                'embeddings': [None],
                'node_indices': torch.tensor([2]).to(self.device)
            }
        }
        # One edge per column: drug modality (1) -> Drug (0) and
        # protein modality (3) -> Protein (2); rows are assumed to be
        # (source, relation id, target).
        triples = torch.tensor([[1, 3],
                                [3, 4],
                                [0, 2]]).to(self.device)
        gnn_embeddings = self.gnn.encoder(nodes, triples)
        node_gnn_embeddings = []
        all_indices = [0, 2]  # read out the Drug and Protein entity nodes

        for index in all_indices:
            # A 1-D index tensor keeps one (1, hd1) row per node
            node_gnn_embedding = torch.index_select(gnn_embeddings, dim=0,
                                                    index=torch.tensor([index]).to(self.device))
            node_gnn_embeddings.append(node_gnn_embedding)

        # (1, 2 * hd1) -> (1, 1)
        c = torch.cat(node_gnn_embeddings, dim=-1)
        for m in self.combine:
            c = m(c)

        return c
```

- Run inference with the initial embeddings, i.e. the embeddings produced by applying the handlers (Morgan fingerprint for the SMILES string, ESM-1b for the protein sequence):

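```python
# `net` is a BindingAffinity instance wrapping the trained Otter GNN;
# drug_embedding and protein_embedding are the initial 1-D embeddings.
p = net(drug_embedding=drug_embedding, protein_embedding=protein_embedding)
print(p)
```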