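"""Relational Transformer over node/edge graph features.

Graph construction is delegated to a hydra-instantiated ``graph_constructor``
that returns ``(node_features, edge_features, mask)``; the rest of this file
defines the transformer blocks (``RTLayer``/``RTAttention``), the edge-update
modules, and the pooling/readout head.
"""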
import hydra
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange

from utils.pooling import HomogeneousAggregator
class RelationalTransformer(nn.Module):
    """Relational transformer over node and edge features.

    A hydra-instantiated ``graph_constructor`` maps the raw inputs to
    ``(node_features, edge_features, mask)`` with node features of shape
    ``[B, N, d_node]`` and edge features of shape ``[B, N, N, d_edge]``.
    A stack of ``RTLayer`` blocks updates both feature sets, and a pooling
    step (or a CLS token) followed by an MLP produces a ``d_out``-dimensional
    graph-level output.
    """

    def __init__(
        self,
        d_node,
        d_edge,
        d_attn_hid,
        d_node_hid,
        d_edge_hid,
        d_out_hid,
        d_out,
        n_layers,
        n_heads,
        layer_layout,
        graph_constructor,
        dropout=0.0,
        node_update_type="rt",
        disable_edge_updates=False,
        use_cls_token=False,
        pooling_method="cat",
        pooling_layer_idx="last",
        rev_edge_features=False,
        modulate_v=True,
        use_ln=True,
        tfixit_init=False,
    ):
        super().__init__()
        assert use_cls_token == (pooling_method == "cls_token")
        self.pooling_method = pooling_method
        self.pooling_layer_idx = pooling_layer_idx
        self.rev_edge_features = rev_edge_features
        self.nodes_per_layer = layer_layout
        self.construct_graph = hydra.utils.instantiate(
            graph_constructor,
            d_node=d_node,
            d_edge=d_edge,
            layer_layout=layer_layout,
            rev_edge_features=rev_edge_features,
        )
        self.use_cls_token = use_cls_token
        if use_cls_token:
            self.cls_token = nn.Parameter(torch.randn(d_node))

        self.layers = nn.ModuleList(
            [
                torch.jit.script(
                    RTLayer(
                        d_node,
                        d_edge,
                        d_attn_hid,
                        d_node_hid,
                        d_edge_hid,
                        n_heads,
                        dropout,
                        node_update_type=node_update_type,
                        disable_edge_updates=(
                            (disable_edge_updates or (i == n_layers - 1))
                            and pooling_method != "mean_edge"
                            and pooling_layer_idx != "all"
                        ),
                        modulate_v=modulate_v,
                        use_ln=use_ln,
                        tfixit_init=tfixit_init,
                        n_layers=n_layers,
                    )
                )
                for i in range(n_layers)
            ]
        )

        if pooling_method != "cls_token":
            self.pool = HomogeneousAggregator(
                pooling_method,
                pooling_layer_idx,
                layer_layout,
            )

        self.num_graph_features = (
            layer_layout[-1] * d_node
            if pooling_method == "cat" and pooling_layer_idx == "last"
            else d_edge
            if pooling_method in ("mean_edge", "max_edge")
            else d_node
        )
        self.proj_out = nn.Sequential(
            nn.Linear(self.num_graph_features, d_out_hid),
            nn.ReLU(),
            # nn.Linear(d_out_hid, d_out_hid),
            # nn.ReLU(),
            nn.Linear(d_out_hid, d_out),
        )
        self.final_features = (None, None, None, None)

    def forward(self, inputs):
        attn_weights = None
        node_features, edge_features, mask = self.construct_graph(inputs)
        if self.use_cls_token:
            node_features = torch.cat(
                [
                    # repeat(self.cls_token, "d -> b 1 d", b=node_features.size(0)),
                    self.cls_token.unsqueeze(0).expand(node_features.size(0), 1, -1),
                    node_features,
                ],
                dim=1,
            )
            edge_features = F.pad(edge_features, (0, 0, 1, 0, 1, 0), value=0)

        for layer in self.layers:
            node_features, edge_features, attn_weights = layer(
                node_features, edge_features, mask
            )

        if self.pooling_method == "cls_token":
            graph_features = node_features[:, 0]
        else:
            graph_features = self.pool(node_features, edge_features)
        self.final_features = (graph_features, node_features, edge_features, attn_weights)
        return self.proj_out(graph_features)
class RTLayer(nn.Module):
    """One relational transformer block: edge-conditioned multi-head
    self-attention over the nodes with residual + LayerNorm, a node MLP,
    and an optional edge update."""

    def __init__(
        self,
        d_node,
        d_edge,
        d_attn_hid,
        d_node_hid,
        d_edge_hid,
        n_heads,
        dropout,
        node_update_type="rt",
        disable_edge_updates=False,
        modulate_v=True,
        use_ln=True,
        tfixit_init=False,
        n_layers=None,
    ):
        super().__init__()
        self.node_update_type = node_update_type
        self.disable_edge_updates = disable_edge_updates
        self.use_ln = use_ln
        self.n_layers = n_layers

        self.self_attn = torch.jit.script(
            RTAttention(
                d_node,
                d_edge,
                d_attn_hid,
                n_heads,
                modulate_v=modulate_v,
                use_ln=use_ln,
            )
        )
        # self.self_attn = RTAttention(d_hid, d_hid, d_hid, n_heads)
        self.lin0 = Linear(d_node, d_node)
        self.dropout0 = nn.Dropout(dropout)
        if use_ln:
            self.node_ln0 = nn.LayerNorm(d_node)
            self.node_ln1 = nn.LayerNorm(d_node)
        else:
            self.node_ln0 = nn.Identity()
            self.node_ln1 = nn.Identity()

        act_fn = nn.GELU
        self.node_mlp = nn.Sequential(
            Linear(d_node, d_node_hid, bias=False),
            act_fn(),
            Linear(d_node_hid, d_node),
            nn.Dropout(dropout),
        )

        if not self.disable_edge_updates:
            self.edge_updates = EdgeLayer(
                d_node=d_node,
                d_edge=d_edge,
                d_edge_hid=d_edge_hid,
                dropout=dropout,
                act_fn=act_fn,
                use_ln=use_ln,
            )
        else:
            self.edge_updates = NoEdgeLayer()

        if tfixit_init:
            self.fixit_init()

    def fixit_init(self):
        # T-Fixup-style depth-dependent rescaling of the initial weights.
        temp_state_dict = self.state_dict()
        n_layers = self.n_layers
        for name, param in self.named_parameters():
            if "weight" in name:
                if name.split(".")[0] in ["node_mlp", "edge_mlp0", "edge_mlp1"]:
                    temp_state_dict[name] = (0.67 * (n_layers) ** (-1.0 / 4.0)) * param
                elif name.split(".")[0] in ["self_attn"]:
                    temp_state_dict[name] = (0.67 * (n_layers) ** (-1.0 / 4.0)) * (
                        param * (2**0.5)
                    )
        self.load_state_dict(temp_state_dict)

    def node_updates(self, node_features, edge_features, mask):
        attn_out, attn_weights = self.self_attn(node_features, edge_features, mask)
        node_features = self.node_ln0(node_features + self.dropout0(self.lin0(attn_out)))
        node_features = self.node_ln1(node_features + self.node_mlp(node_features))
        return node_features, attn_weights

    def forward(self, node_features, edge_features, mask):
        node_features, attn_weights = self.node_updates(node_features, edge_features, mask)
        edge_features = self.edge_updates(node_features, edge_features, mask)
        return node_features, edge_features, attn_weights
class EdgeLayer(nn.Module):
    """Residual edge update: an ``EdgeMLP`` conditioned on the incident nodes,
    followed by a per-edge feed-forward block, each with LayerNorm."""

    def __init__(
        self,
        *,
        d_node,
        d_edge,
        d_edge_hid,
        dropout,
        act_fn,
        use_ln=True,
    ) -> None:
        super().__init__()
        self.edge_mlp0 = EdgeMLP(
            d_edge=d_edge,
            d_node=d_node,
            d_edge_hid=d_edge_hid,
            act_fn=act_fn,
            dropout=dropout,
        )
        self.edge_mlp1 = nn.Sequential(
            Linear(d_edge, d_edge_hid, bias=False),
            act_fn(),
            Linear(d_edge_hid, d_edge),
            nn.Dropout(dropout),
        )
        if use_ln:
            self.eln0 = nn.LayerNorm(d_edge)
            self.eln1 = nn.LayerNorm(d_edge)
        else:
            self.eln0 = nn.Identity()
            self.eln1 = nn.Identity()

    def forward(self, node_features, edge_features, mask):
        edge_features = self.eln0(
            edge_features + self.edge_mlp0(node_features, edge_features)
        )
        edge_features = self.eln1(edge_features + self.edge_mlp1(edge_features))
        return edge_features


class NoEdgeLayer(nn.Module):
    """Identity edge update used when edge updates are disabled."""

    def forward(self, node_features, edge_features, mask):
        return edge_features
class EdgeMLP(nn.Module):
    """Updates each directed edge (i, j) from its own features, the reversed
    edge (j, i), and projections of its source and target node features."""

    def __init__(self, *, d_node, d_edge, d_edge_hid, act_fn, dropout):
        super().__init__()
        self.reverse_edge = Rearrange("b n m d -> b m n d")
        self.lin0_e = Linear(2 * d_edge, d_edge_hid)
        self.lin0_s = Linear(d_node, d_edge_hid)
        self.lin0_t = Linear(d_node, d_edge_hid)
        self.act = act_fn()
        self.lin1 = Linear(d_edge_hid, d_edge)
        self.drop = nn.Dropout(dropout)

    def forward(self, node_features, edge_features):
        source_nodes = (
            self.lin0_s(node_features)
            .unsqueeze(-2)
            .expand(-1, -1, node_features.size(-2), -1)
        )
        target_nodes = (
            self.lin0_t(node_features)
            .unsqueeze(-3)
            .expand(-1, node_features.size(-2), -1, -1)
        )

        # reversed_edge_features = self.reverse_edge(edge_features)
        edge_features = self.lin0_e(
            torch.cat([edge_features, self.reverse_edge(edge_features)], dim=-1)
        )
        edge_features = edge_features + source_nodes + target_nodes
        edge_features = self.act(edge_features)
        edge_features = self.lin1(edge_features)
        edge_features = self.drop(edge_features)
        return edge_features
class RTAttention(nn.Module):
    """Multi-head self-attention in which queries and keys are additively
    modulated by per-edge projections, and values are optionally gated by
    them as well (``modulate_v``)."""

    def __init__(self, d_node, d_edge, d_hid, n_heads, modulate_v=None, use_ln=True):
        super().__init__()
        self.n_heads = n_heads
        self.d_node = d_node
        self.d_edge = d_edge
        self.d_hid = d_hid
        self.use_ln = use_ln
        self.modulate_v = modulate_v
        self.scale = 1 / (d_hid**0.5)
        self.split_head_node = Rearrange("b n (h d) -> b h n d", h=n_heads)
        self.split_head_edge = Rearrange("b n m (h d) -> b h n m d", h=n_heads)
        self.cat_head_node = Rearrange("... h n d -> ... n (h d)", h=n_heads)

        self.qkv_node = Linear(d_node, 3 * d_hid, bias=False)
        self.edge_factor = 4 if modulate_v else 3
        self.qkv_edge = Linear(d_edge, self.edge_factor * d_hid, bias=False)
        self.proj_out = Linear(d_hid, d_node)

    def forward(self, node_features, edge_features, mask):
        qkv_node = self.qkv_node(node_features)
        # qkv_node = rearrange(qkv_node, "b n (h d) -> b h n d", h=self.n_heads)
        qkv_node = self.split_head_node(qkv_node)
        q_node, k_node, v_node = torch.chunk(qkv_node, 3, dim=-1)

        qkv_edge = self.qkv_edge(edge_features)
        # qkv_edge = rearrange(qkv_edge, "b n m (h d) -> b h n m d", h=self.n_heads)
        qkv_edge = self.split_head_edge(qkv_edge)
        qkv_edge = torch.chunk(qkv_edge, self.edge_factor, dim=-1)
        # q_edge, k_edge, v_edge, q_edge_b, k_edge_b, v_edge_b = torch.chunk(
        #     qkv_edge, 6, dim=-1
        # )
        # qkv_edge = [item.masked_fill(mask.unsqueeze(1) == 0, 0) for item in qkv_edge]

        q = q_node.unsqueeze(-2) + qkv_edge[0]  # + q_edge_b
        k = k_node.unsqueeze(-3) + qkv_edge[1]  # + k_edge_b
        if self.modulate_v:
            v = v_node.unsqueeze(-3) * qkv_edge[3] + qkv_edge[2]
        else:
            v = v_node.unsqueeze(-3) + qkv_edge[2]

        dots = self.scale * torch.einsum("b h i j d, b h i j d -> b h i j", q, k)
        # dots.masked_fill_(mask.unsqueeze(1).squeeze(-1) == 0, -1e-9)
        attn = F.softmax(dots, dim=-1)
        out = torch.einsum("b h i j, b h i j d -> b h i d", attn, v)
        out = self.cat_head_node(out)
        return self.proj_out(out), attn
def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)  # , gain=1 / math.sqrt(2))
    if bias:
        nn.init.constant_(m.bias, 0.0)
    return m
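

# Minimal shape-check sketch for the edge-conditioned attention block alone;
# unlike the full RelationalTransformer it needs no hydra graph constructor.
# The dimensions below are illustrative placeholders, not values from any
# original configuration; the mask tensor is currently unused by forward().
if __name__ == "__main__":
    batch, n_nodes = 2, 5
    d_node, d_edge, d_hid, n_heads = 16, 8, 32, 4  # d_hid must be divisible by n_heads

    attn = RTAttention(d_node, d_edge, d_hid, n_heads, modulate_v=True, use_ln=True)
    nodes = torch.randn(batch, n_nodes, d_node)
    edges = torch.randn(batch, n_nodes, n_nodes, d_edge)
    mask = torch.ones(batch, n_nodes, n_nodes, 1)

    out, weights = attn(nodes, edges, mask)
    print(out.shape)      # torch.Size([2, 5, 16])  -> per-node outputs in d_node
    print(weights.shape)  # torch.Size([2, 4, 5, 5]) -> per-head attention over node pairs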