Spaces:
Running
Running
import torch | |
from torch.nn import Linear, Dropout, Module, Sequential | |
from torch_geometric.nn import GINEConv, global_mean_pool | |
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder | |
class HybridGNN(Module): | |
def __init__(self, gnn_dim, rdkit_dim, hidden_dim, dropout_rate=0.2, activation='ReLU'): | |
super().__init__() | |
act_map = {'Swish': torch.nn.SiLU(), 'ReLU': torch.nn.ReLU()} | |
act_fn = act_map[activation] | |
self.gnn_dim = gnn_dim | |
self.rdkit_dim = rdkit_dim | |
self.atom_encoder = AtomEncoder(emb_dim=gnn_dim) | |
self.bond_encoder = BondEncoder(emb_dim=gnn_dim) | |
self.conv1 = GINEConv(Sequential(Linear(gnn_dim, gnn_dim), act_fn, Linear(gnn_dim, gnn_dim))) | |
self.conv2 = GINEConv(Sequential(Linear(gnn_dim, gnn_dim), act_fn, Linear(gnn_dim, gnn_dim))) | |
self.pool = global_mean_pool | |
self.mlp = Sequential(Linear(gnn_dim + rdkit_dim, hidden_dim), act_fn, | |
Dropout(dropout_rate), | |
Linear(hidden_dim, hidden_dim // 2), act_fn, | |
Dropout(dropout_rate), | |
Linear(hidden_dim // 2, 1)) | |
def forward(self, data): | |
x = self.atom_encoder(data.x) | |
edge_attr = self.bond_encoder(data.edge_attr) | |
x = self.conv1(x, data.edge_index, edge_attr) | |
x = self.conv2(x, data.edge_index, edge_attr) | |
x = self.pool(x, data.batch) | |
rdkit_feats = getattr(data, 'rdkit_feats', None) | |
if rdkit_feats is not None: | |
if x.shape[0] != rdkit_feats.shape[0]: | |
raise ValueError(f"Shape mismatch: GNN output ({x.shape}) vs rdkit_feats ({rdkit_feats.shape})") | |
x = torch.cat([x, rdkit_feats], dim=1) | |
else: | |
raise ValueError("RDKit features not found in the data object.") | |
return self.mlp(x) | |
def load_model(rdkit_dim: int, path: str = "best_hybridgnn.pt", device: str = "cpu"): | |
model = HybridGNN(gnn_dim=512, rdkit_dim=rdkit_dim, hidden_dim=256, dropout_rate=0.29, activation='Swish') | |
model.load_state_dict(torch.load(path, map_location=device)) | |
model.to(device) | |
model.eval() | |
return model | |