---
license: mit
inference: false
datasets:
- ibm-research/otter_uniprot_bindingdb
---

# Otter UB MF Model Card

## Model details
Otter models are based on Graph Neural Networks (GNNs) that propagate initial embeddings through a set of layers, updating each node's embedding according to its neighbours. The architecture consists of two main blocks: an encoder and a decoder.
- The encoder first applies a projection layer, a set of linear transformations (one per node modality) that projects all nodes into a common dimensionality, followed by several multi-relational graph convolutional layers (R-GCN), which distinguish between the different edge types connecting source and target nodes by keeping a set of trainable parameters for each edge type.
- The decoder addresses a link-prediction task: a scoring function maps each triple of source node, target node, and connecting edge to a scalar in the interval [0, 1].

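The encoder/decoder split above can be sketched in a few lines of PyTorch. This is a minimal, hypothetical illustration of the idea, not the repo's implementation: the names below (`TinyEncoder`, `score_link`, the modality keys and dimensions) are invented for this sketch, and the scorer is a DistMult-style product chosen for brevity.

```python
import torch
from torch import nn

# Minimal sketch (illustrative, not the otter-knowledge API): per-modality
# projections into a shared dimension, one R-GCN-style layer with a weight
# matrix per edge type, and a scorer squashed into [0, 1].

class TinyEncoder(nn.Module):
    def __init__(self, modality_dims, dim, num_relations):
        super().__init__()
        # projection layer: one linear map per node modality
        self.proj = nn.ModuleDict({m: nn.Linear(d, dim) for m, d in modality_dims.items()})
        # one trainable weight matrix per edge type, plus a self-loop
        self.rel = nn.ModuleList(nn.Linear(dim, dim, bias=False) for _ in range(num_relations))
        self.self_loop = nn.Linear(dim, dim, bias=False)

    def forward(self, inputs, edges):
        # inputs: list of (modality, feature_vector), one entry per node
        h = torch.stack([self.proj[m](x) for m, x in inputs])
        out = self.self_loop(h)
        for src, rel_id, dst in edges:  # edges: (source, edge type, target)
            out[dst] = out[dst] + self.rel[rel_id](h[src])
        return torch.relu(out)

def score_link(h_src, rel_emb, h_dst):
    # DistMult-style decoder: maps a triple to a scalar in [0, 1]
    return torch.sigmoid((h_src * rel_emb * h_dst).sum(-1))
```

Keeping a separate weight matrix per relation is what lets the encoder treat, say, a drug-protein edge differently from a protein-sequence edge.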
**Model training data:**

The model was trained over *Uniprot-BindingDB*.

**Paper or resources for more information:**
- [GitHub Repo](https://github.com/IBM/otter-knowledge)
- [Paper](https://arxiv.org/abs/2306.12802)

**License:**

MIT

**Where to send questions or comments about the model:**
- [GitHub Repo](https://github.com/IBM/otter-knowledge)

## How to use

Clone the repo:
```sh
git clone https://github.com/IBM/otter-knowledge.git
cd otter-knowledge
```

- Use the `BindingAffinity` class:

```python
import torch
from torch import nn


class BindingAffinity(nn.Module):

    def __init__(self, gnn, drug_modality):
        super(BindingAffinity, self).__init__()
        self.drug_modality = drug_modality
        self.protein_modality = 'protein-sequence-mean'
        self.drug_entity_name = 'Drug'
        self.protein_entity_name = 'Protein'
        self.drug_rel_id = 1
        self.protein_rel_id = 2
        self.protein_drug_rel_id = 0
        self.gnn = gnn
        self.device = 'cpu'
        hd1 = 512
        num_input = 2
        # MLP head mapping the concatenated drug and protein GNN
        # embeddings (2 * 512) to a single affinity score
        self.combine = torch.nn.ModuleList([nn.Linear(num_input * hd1, hd1), nn.ReLU(),
                                            nn.Linear(hd1, hd1), nn.ReLU(),
                                            nn.Linear(hd1, 1)])
        self.to(self.device)

    def forward(self, drug_embedding, protein_embedding):
        # Build a small graph: drug entity (0), drug modality (1),
        # protein entity (2), protein modality (3)
        nodes = {
            self.drug_modality: {
                'embeddings': drug_embedding.unsqueeze(0).to(self.device),
                'node_indices': torch.tensor([1]).to(self.device)
            },
            self.drug_entity_name: {
                'embeddings': [None],
                'node_indices': torch.tensor([0]).to(self.device)
            },
            self.protein_modality: {
                'embeddings': protein_embedding.unsqueeze(0).to(self.device),
                'node_indices': torch.tensor([3]).to(self.device)
            },
            self.protein_entity_name: {
                'embeddings': [None],
                'node_indices': torch.tensor([2]).to(self.device)
            }
        }
        triples = torch.tensor([[1, 3],
                                [3, 4],
                                [0, 2]]).to(self.device)
        gnn_embeddings = self.gnn.encoder(nodes, triples)
        # Select the drug (index 0) and protein (index 2) entity embeddings
        node_gnn_embeddings = []
        all_indices = [0, 2]

        for indices in all_indices:
            node_gnn_embedding = torch.index_select(gnn_embeddings, dim=0,
                                                    index=torch.tensor(indices).to(self.device))
            node_gnn_embeddings.append(node_gnn_embedding)

        # Concatenate and pass through the MLP head
        c = torch.cat(node_gnn_embeddings, dim=-1)
        for m in self.combine:
            c = m(c)

        return c
```

- Run the inference with the initial embeddings (obtained by applying the handlers (Morgan fingerprint, ESM-1b) to the SMILES string and the protein sequence):

```python
p = net(drug_embedding=drug_embedding, protein_embedding=protein_embedding)
print(p)
```
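
To smoke-test the wiring before loading a trained checkpoint, a stand-in with the same `gnn.encoder(nodes, triples)` interface can be substituted for the real GNN. Everything below (`DummyGNN`, the random 512-d outputs, the assumed 2048-bit Morgan fingerprint and 1280-d ESM-1b dimensions) is illustrative, not part of the repo:

```python
import torch
from torch import nn

# Illustrative stand-in for the trained GNN (not the otter-knowledge API):
# it exposes the same gnn.encoder(nodes, triples) call used above and
# returns one random 512-d embedding per node, so shapes can be checked
# before downloading the real checkpoint.

class DummyEncoder(nn.Module):
    def forward(self, nodes, triples):
        num_nodes = sum(len(v['node_indices']) for v in nodes.values())
        return torch.randn(num_nodes, 512)

class DummyGNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = DummyEncoder()

# Example inputs with assumed dimensions: a 2048-bit Morgan fingerprint
# for the drug and a 1280-d ESM-1b embedding for the protein.
nodes = {
    'morgan-fingerprint': {'embeddings': torch.randn(1, 2048),
                           'node_indices': torch.tensor([1])},
    'Drug': {'embeddings': [None], 'node_indices': torch.tensor([0])},
    'protein-sequence-mean': {'embeddings': torch.randn(1, 1280),
                              'node_indices': torch.tensor([3])},
    'Protein': {'embeddings': [None], 'node_indices': torch.tensor([2])},
}
triples = torch.tensor([[1, 3], [3, 4], [0, 2]])
embeddings = DummyGNN().encoder(nodes, triples)
print(embeddings.shape)  # torch.Size([4, 512])
```

Passing `DummyGNN()` as the `gnn` argument of `BindingAffinity` then exercises the full forward pass end to end, with the real encoder swapped in only when the checkpoint is available.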