# NOTE(review): removed non-code paste artifacts ("Spaces:", "Runtime error" x2)
# that preceded the source — they were output from whatever tool this was copied from.
import numpy as np
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical

from src.egnn import GCL
class DistributionNodes:
    """Categorical distribution over node counts, fit from an empirical histogram.

    Built from a mapping ``{n_nodes: frequency}``; supports sampling node
    counts and evaluating the log-probability of a batch of counts.
    """

    def __init__(self, histogram):
        """
        Args:
            histogram: mapping from node count (int) to observed frequency.
        """
        self.n_nodes = []
        prob = []
        self.keys = {}  # node count -> index into self.n_nodes / self.prob
        for i, nodes in enumerate(histogram):
            self.n_nodes.append(nodes)
            self.keys[nodes] = i
            prob.append(histogram[nodes])
        self.n_nodes = torch.tensor(self.n_nodes)
        prob = np.array(prob)
        prob = prob / np.sum(prob)  # normalize frequencies to a proper distribution
        self.prob = torch.from_numpy(prob).float()
        # Build the sampler from self.prob directly, so sample() and log_prob()
        # share one tensor. (Original passed torch.tensor(prob), a separate
        # float64 copy, while log_prob used this float32 tensor.)
        self.m = Categorical(self.prob)

    def sample(self, n_samples=1):
        """Draw ``n_samples`` node counts, weighted by the histogram."""
        idx = self.m.sample((n_samples,))
        return self.n_nodes[idx]

    def log_prob(self, batch_n_nodes):
        """Return log-probability for each entry of a 1-D tensor of node counts.

        Raises:
            KeyError: if a count was not present in the histogram.
        """
        assert len(batch_n_nodes.size()) == 1
        idcs = [self.keys[i.item()] for i in batch_n_nodes]
        idcs = torch.tensor(idcs).to(batch_n_nodes.device)
        # Epsilon guards against log(0) for zero-frequency histogram entries.
        log_p = torch.log(self.prob + 1e-30)
        log_p = log_p.to(batch_n_nodes.device)
        log_probs = log_p[idcs]
        return log_probs
class SizeGNN(nn.Module):
    """Graph network of stacked GCL layers producing per-node output features.

    Embeds input node features to ``hidden_nf``, applies ``n_layers``
    identically-configured GCL message-passing layers, and projects to
    ``out_node_nf``.
    """

    def __init__(self, in_node_nf, hidden_nf, out_node_nf, n_layers, normalization, device='cpu'):
        """
        Args:
            in_node_nf: input node feature dimension.
            hidden_nf: hidden feature dimension used by every GCL layer.
            out_node_nf: output node feature dimension.
            n_layers: total number of GCL layers (first one plus n_layers - 1 more).
            normalization: normalization setting forwarded to each GCL.
            device: torch device the module is moved to.
        """
        super(SizeGNN, self).__init__()
        self.hidden_nf = hidden_nf
        self.out_node_nf = out_node_nf
        self.device = device

        # All GCL layers are configured identically; build them through one
        # factory instead of duplicating the 10-argument constructor call.
        def make_gcl():
            return GCL(
                input_nf=self.hidden_nf,
                output_nf=self.hidden_nf,
                hidden_nf=self.hidden_nf,
                normalization_factor=1,
                aggregation_method='sum',
                edges_in_d=1,
                activation=nn.ReLU(),
                attention=False,
                normalization=normalization,
            )

        self.embedding_in = nn.Linear(in_node_nf, self.hidden_nf)
        # Attribute names gcl1 / gcl_layers are kept so existing state_dict
        # checkpoints still load unchanged.
        self.gcl1 = make_gcl()
        self.gcl_layers = nn.ModuleList([make_gcl() for _ in range(n_layers - 1)])
        self.embedding_out = nn.Linear(self.hidden_nf, self.out_node_nf)
        self.to(self.device)

    def forward(self, h, edges, distances, node_mask, edge_mask):
        """Run the embed -> GCL stack -> project pipeline.

        Args:
            h: node features.
            edges: edge index structure expected by GCL.
            distances: per-edge attribute (edges_in_d=1).
            node_mask, edge_mask: masks forwarded to each GCL layer.

        Returns:
            Per-node output features of dimension ``out_node_nf``.
        """
        h = self.embedding_in(h)
        h, _ = self.gcl1(h, edges, edge_attr=distances, node_mask=node_mask, edge_mask=edge_mask)
        for gcl in self.gcl_layers:
            h, _ = gcl(h, edges, edge_attr=distances, node_mask=node_mask, edge_mask=edge_mask)
        h = self.embedding_out(h)
        return h