|
| 1 | +from typing import Dict, List, Tuple |
| 2 | + |
| 3 | +import networkx as nx |
| 4 | +import torch |
| 5 | +import torch_geometric.data as gd |
| 6 | +from networkx.algorithms.isomorphism import is_isomorphic as nx_is_isomorphic |
| 7 | + |
| 8 | +from gflownet.envs.graph_building_env import ( |
| 9 | + Graph, |
| 10 | + GraphAction, |
| 11 | + GraphActionType, |
| 12 | + GraphBuildingEnvContext, |
| 13 | + graph_without_edge, |
| 14 | +) |
| 15 | +from gflownet.utils.graphs import random_walk_probs |
| 16 | + |
| 17 | + |
| 18 | +def hashg(g): |
| 19 | + return nx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash(g, node_attr="v") |
| 20 | + |
| 21 | + |
| 22 | +def is_isomorphic(u, v): |
| 23 | + return nx_is_isomorphic(u, v, lambda a, b: a == b, lambda a, b: a == b) |
| 24 | + |
| 25 | + |
| 26 | +class BasicGraphContext(GraphBuildingEnvContext): |
| 27 | + """ |
| 28 | + A basic graph generation context. |
| 29 | +
|
| 30 | + This simple environment context is designed to be used to test implementations. It only allows for AddNode and |
| 31 | + AddEdge actions, and is meant to be used within the BasicGraphTask to generate graphs of up to 7 nodes with |
| 32 | + only two possible node attributes, making the state space a total of ~70k states (which is nicely enumerable |
| 33 | + and allows us to compute p_theta(x) exactly for all x in the state space). |
| 34 | + """ |
| 35 | + |
| 36 | + def __init__(self, max_nodes=7, num_cond_dim=0, graph_data=None, output_gid=False): |
| 37 | + self.max_nodes = max_nodes |
| 38 | + self.output_gid = output_gid |
| 39 | + |
| 40 | + self.node_attr_values = { |
| 41 | + "v": [0, 1], # Imagine this is as colors |
| 42 | + } |
| 43 | + self._num_rw_feat = 8 |
| 44 | + |
| 45 | + self.num_new_node_values = len(self.node_attr_values["v"]) |
| 46 | + self.num_node_attr_logits = None |
| 47 | + self.num_node_dim = self.num_new_node_values + 1 + self._num_rw_feat |
| 48 | + self.num_node_attrs = 1 |
| 49 | + self.num_edge_attr_logits = None |
| 50 | + self.num_edge_attrs = 0 |
| 51 | + self.num_cond_dim = num_cond_dim |
| 52 | + self.num_edge_dim = 1 |
| 53 | + self.edges_are_duplicated = True |
| 54 | + self.edges_are_unordered = True |
| 55 | + |
| 56 | + # Order in which models have to output logits |
| 57 | + self.action_type_order = [ |
| 58 | + GraphActionType.Stop, |
| 59 | + GraphActionType.AddNode, |
| 60 | + GraphActionType.AddEdge, |
| 61 | + ] |
| 62 | + self.bck_action_type_order = [ |
| 63 | + GraphActionType.RemoveNode, |
| 64 | + GraphActionType.RemoveEdge, |
| 65 | + ] |
| 66 | + self.device = torch.device("cpu") |
| 67 | + self.graph_data = graph_data |
| 68 | + self.hash_to_graphs: Dict[str, int] = {} |
| 69 | + if graph_data is not None: |
| 70 | + states_hash = [hashg(i) for i in graph_data] |
| 71 | + for i, h, g in zip(range(len(graph_data)), states_hash, graph_data): |
| 72 | + self.hash_to_graphs[h] = self.hash_to_graphs.get(h, list()) + [(g, i)] |
| 73 | + |
| 74 | + def get_graph_idx(self, g, default=None): |
| 75 | + h = hashg(g) |
| 76 | + if h not in self.hash_to_graphs and default is not None: |
| 77 | + return default |
| 78 | + bucket = self.hash_to_graphs[h] |
| 79 | + if len(bucket) == 1: |
| 80 | + return bucket[0][1] |
| 81 | + for i in bucket: |
| 82 | + if is_isomorphic(i[0], g): |
| 83 | + return i[1] |
| 84 | + if default is not None: |
| 85 | + return default |
| 86 | + raise ValueError(g) |
| 87 | + |
| 88 | + def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: bool = True): |
| 89 | + """Translate an action index (e.g. from a GraphActionCategorical) to a GraphAction""" |
| 90 | + act_type, act_row, act_col = [int(i) for i in action_idx] |
| 91 | + if fwd: |
| 92 | + t = self.action_type_order[act_type] |
| 93 | + else: |
| 94 | + t = self.bck_action_type_order[act_type] |
| 95 | + |
| 96 | + if t is GraphActionType.Stop: |
| 97 | + return GraphAction(t) |
| 98 | + elif t is GraphActionType.AddNode: |
| 99 | + return GraphAction(t, source=act_row, value=self.node_attr_values["v"][act_col]) |
| 100 | + elif t is GraphActionType.AddEdge: |
| 101 | + a, b = g.non_edge_index[:, act_row] |
| 102 | + return GraphAction(t, source=a.item(), target=b.item()) |
| 103 | + elif t is GraphActionType.RemoveNode: |
| 104 | + return GraphAction(t, source=act_row) |
| 105 | + elif t is GraphActionType.RemoveEdge: |
| 106 | + a, b = g.edge_index[:, act_row * 2] |
| 107 | + return GraphAction(t, source=a.item(), target=b.item()) |
| 108 | + |
| 109 | + def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int, int]: |
| 110 | + """Translate a GraphAction to an index tuple""" |
| 111 | + if action.action is GraphActionType.Stop: |
| 112 | + row = col = 0 |
| 113 | + type_idx = self.action_type_order.index(action.action) |
| 114 | + elif action.action is GraphActionType.AddNode: |
| 115 | + row = action.source |
| 116 | + col = self.node_attr_values["v"].index(action.value) |
| 117 | + type_idx = self.action_type_order.index(action.action) |
| 118 | + elif action.action is GraphActionType.AddEdge: |
| 119 | + # Here we have to retrieve the index in non_edge_index of an edge (s,t) |
| 120 | + # that's also possibly in the reverse order (t,s). |
| 121 | + # That's definitely not too efficient, can we do better? |
| 122 | + row = ( |
| 123 | + (g.non_edge_index.T == torch.tensor([(action.source, action.target)])).prod(1) |
| 124 | + + (g.non_edge_index.T == torch.tensor([(action.target, action.source)])).prod(1) |
| 125 | + ).argmax() |
| 126 | + col = 0 |
| 127 | + type_idx = self.action_type_order.index(action.action) |
| 128 | + elif action.action is GraphActionType.RemoveNode: |
| 129 | + row = action.source |
| 130 | + col = 0 |
| 131 | + type_idx = self.bck_action_type_order.index(action.action) |
| 132 | + elif action.action is GraphActionType.RemoveEdge: |
| 133 | + row = ((g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1)).argmax() |
| 134 | + row = int(row) // 2 # edges are duplicated, but edge logits are not |
| 135 | + col = 0 |
| 136 | + type_idx = self.bck_action_type_order.index(action.action) |
| 137 | + return (type_idx, int(row), int(col)) |
| 138 | + |
| 139 | + def graph_to_Data(self, g: Graph) -> gd.Data: |
| 140 | + """Convert a networkx Graph to a torch geometric Data instance""" |
| 141 | + x = torch.zeros((max(1, len(g.nodes)), self.num_node_dim - self._num_rw_feat)) |
| 142 | + x[0, -1] = len(g.nodes) == 0 |
| 143 | + remove_node_mask = torch.zeros((x.shape[0], 1)) + (1 if len(g) == 0 else 0) |
| 144 | + for i, n in enumerate(g.nodes): |
| 145 | + ad = g.nodes[n] |
| 146 | + x[i, self.node_attr_values["v"].index(ad["v"])] = 1 |
| 147 | + if g.degree(n) <= 1: |
| 148 | + remove_node_mask[i] = 1 |
| 149 | + |
| 150 | + remove_edge_mask = torch.zeros((len(g.edges), 1)) |
| 151 | + for i, (u, v) in enumerate(g.edges): |
| 152 | + if g.degree(u) > 1 and g.degree(v) > 1: |
| 153 | + if nx.algorithms.is_connected(graph_without_edge(g, (u, v))): |
| 154 | + remove_edge_mask[i] = 1 |
| 155 | + edge_attr = torch.zeros((len(g.edges) * 2, self.num_edge_dim)) |
| 156 | + edge_index = ( |
| 157 | + torch.tensor([e for i, j in g.edges for e in [(i, j), (j, i)]], dtype=torch.long).reshape((-1, 2)).T |
| 158 | + ) |
| 159 | + gc = nx.complement(g) |
| 160 | + non_edge_index = torch.tensor([i for i in gc.edges], dtype=torch.long).reshape((-1, 2)).T |
| 161 | + gid = self.get_graph_idx(g) if self.output_gid else 0 |
| 162 | + |
| 163 | + return self._preprocess( |
| 164 | + gd.Data( |
| 165 | + x, |
| 166 | + edge_index, |
| 167 | + edge_attr, |
| 168 | + non_edge_index=non_edge_index, |
| 169 | + stop_mask=torch.ones((1, 1)), |
| 170 | + add_node_mask=torch.ones((x.shape[0], self.num_new_node_values)) * (len(g) < self.max_nodes), |
| 171 | + add_edge_mask=torch.ones((non_edge_index.shape[1], 1)), |
| 172 | + remove_node_mask=remove_node_mask, |
| 173 | + remove_edge_mask=remove_edge_mask, |
| 174 | + gid=gid, |
| 175 | + ) |
| 176 | + ) |
| 177 | + |
| 178 | + def _preprocess(self, g: gd.Data) -> gd.Data: |
| 179 | + if self._num_rw_feat > 0: |
| 180 | + g.x = torch.cat([g.x, random_walk_probs(g, self._num_rw_feat, skip_odd=True)], 1) |
| 181 | + return g |
| 182 | + |
| 183 | + def collate(self, graphs: List[gd.Data]): |
| 184 | + """Batch Data instances""" |
| 185 | + return gd.Batch.from_data_list(graphs, follow_batch=["edge_index", "non_edge_index"]) |
| 186 | + |
| 187 | + def mol_to_graph(self, obj: Graph) -> Graph: |
| 188 | + return obj # This is already a graph |
| 189 | + |
| 190 | + def graph_to_mol(self, g: Graph) -> Graph: |
| 191 | + # idem |
| 192 | + return g |
| 193 | + |
| 194 | + def is_sane(self, g: Graph) -> bool: |
| 195 | + return True |
| 196 | + |
| 197 | + def get_object_description(self, g: Graph, is_valid: bool) -> str: |
| 198 | + return str(self.get_graph_idx(g, -1)) |
0 commit comments