import torch
from torch.nn import Linear, Dropout, Module, Sequential
from torch_geometric.nn import GINEConv, global_mean_pool
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder

# Activation name -> module class. Classes (not instances) are stored so every
# layer gets its own activation module instead of sharing a single object.
_ACTIVATIONS = {'Swish': torch.nn.SiLU, 'ReLU': torch.nn.ReLU}


class HybridGNN(Module):
    """Graph network whose pooled embedding is fused with per-graph RDKit features.

    Node and edge attributes are embedded with the OGB ``AtomEncoder`` and
    ``BondEncoder``, passed through two ``GINEConv`` layers, mean-pooled per
    graph, concatenated with ``data.rdkit_feats`` and fed to an MLP that ends
    in a single output unit.

    Args:
        gnn_dim: Embedding width used by the encoders and both conv layers.
        rdkit_dim: Number of RDKit descriptor features expected per graph.
        hidden_dim: Width of the first MLP layer; the second uses
            ``hidden_dim // 2``.
        dropout_rate: Dropout probability applied after each hidden MLP layer.
        activation: Key into the supported activations (``'Swish'`` or
            ``'ReLU'``).

    Raises:
        KeyError: If ``activation`` is not a supported name.
    """

    def __init__(self, gnn_dim, rdkit_dim, hidden_dim, dropout_rate=0.2, activation='ReLU'):
        super().__init__()
        if activation not in _ACTIVATIONS:
            # Same exception type as the plain dict lookup would raise, but
            # with a message that tells the caller what is accepted.
            raise KeyError(
                f"Unsupported activation {activation!r}; "
                f"expected one of {sorted(_ACTIVATIONS)}"
            )
        make_act = _ACTIVATIONS[activation]

        self.gnn_dim = gnn_dim
        self.rdkit_dim = rdkit_dim

        self.atom_encoder = AtomEncoder(emb_dim=gnn_dim)
        self.bond_encoder = BondEncoder(emb_dim=gnn_dim)

        # Layer order inside each Sequential is kept fixed so that the
        # index-based state_dict keys stay compatible with saved checkpoints.
        self.conv1 = GINEConv(
            Sequential(Linear(gnn_dim, gnn_dim), make_act(), Linear(gnn_dim, gnn_dim))
        )
        self.conv2 = GINEConv(
            Sequential(Linear(gnn_dim, gnn_dim), make_act(), Linear(gnn_dim, gnn_dim))
        )
        self.pool = global_mean_pool

        self.mlp = Sequential(
            Linear(gnn_dim + rdkit_dim, hidden_dim),
            make_act(),
            Dropout(dropout_rate),
            Linear(hidden_dim, hidden_dim // 2),
            make_act(),
            Dropout(dropout_rate),
            Linear(hidden_dim // 2, 1),
        )

    def forward(self, data) -> torch.Tensor:
        """Run the model on a graph (or batch of graphs).

        Args:
            data: Object exposing ``x``, ``edge_attr``, ``edge_index``,
                ``batch`` and ``rdkit_feats`` attributes (presumably a PyG
                ``Data``/``Batch`` -- confirm against the caller).

        Returns:
            The MLP output, one row per pooled graph with a single column.

        Raises:
            ValueError: If ``rdkit_feats`` is missing or ``None``, or if its
                first dimension does not match the number of pooled graphs.
        """
        rdkit_feats = getattr(data, 'rdkit_feats', None)
        if rdkit_feats is None:
            raise ValueError("RDKit features not found in the data object.")

        x = self.atom_encoder(data.x)
        edge_attr = self.bond_encoder(data.edge_attr)
        x = self.conv1(x, data.edge_index, edge_attr)
        x = self.conv2(x, data.edge_index, edge_attr)
        x = self.pool(x, data.batch)

        num_graphs = x.shape[0]
        # Per-graph descriptors stored as 1-D vectors end up flattened to
        # (num_graphs * rdkit_dim,) when graphs are collated along dim 0.
        # Restore the 2-D layout only when the element count matches exactly,
        # so genuinely malformed inputs still hit the check below.
        if rdkit_feats.dim() == 1 and rdkit_feats.numel() == num_graphs * self.rdkit_dim:
            rdkit_feats = rdkit_feats.reshape(num_graphs, self.rdkit_dim)

        if num_graphs != rdkit_feats.shape[0]:
            raise ValueError(f"Shape mismatch: GNN output ({x.shape}) vs rdkit_feats ({rdkit_feats.shape})")

        # Align dtype with the embedding: otherwise e.g. float64 descriptors
        # would promote the concatenation and break the MLP's Linear layers.
        rdkit_feats = rdkit_feats.to(dtype=x.dtype)

        return self.mlp(torch.cat([x, rdkit_feats], dim=1))


def load_model(
    rdkit_dim: int,
    path: str = "best_hybridgnn.pt",
    device: str = "cpu",
    *,
    gnn_dim: int = 512,
    hidden_dim: int = 256,
    dropout_rate: float = 0.29,
    activation: str = 'Swish',
) -> HybridGNN:
    """Build a ``HybridGNN``, load weights from ``path`` and return it in eval mode.

    Args:
        rdkit_dim: Number of RDKit features the checkpoint was trained with.
        path: Location of the saved ``state_dict``.
        device: Device the weights are mapped to and the model is moved to.
        gnn_dim: Embedding width; must match the checkpoint.
        hidden_dim: First MLP layer width; must match the checkpoint.
        dropout_rate: Dropout probability (inactive in eval mode).
        activation: Activation name passed to ``HybridGNN``.

    Returns:
        The model on ``device`` with ``eval()`` applied.
    """
    model = HybridGNN(
        gnn_dim=gnn_dim,
        rdkit_dim=rdkit_dim,
        hidden_dim=hidden_dim,
        dropout_rate=dropout_rate,
        activation=activation,
    )
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code on older PyTorch defaults. Only load checkpoints from a
    # trusted source; consider weights_only=True where the installed PyTorch
    # supports it.
    state_dict = torch.load(path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model