|
1 | 1 | import copy
|
2 |
| -from typing import List |
| 2 | +from typing import List, Optional |
3 | 3 |
|
4 | 4 | import torch
|
5 | 5 | import torch.nn as nn
|
6 | 6 | from torch import Tensor
|
7 | 7 |
|
8 |
| -from gflownet.envs.graph_building_env import GraphAction, GraphActionType |
| 8 | +from gflownet.envs.graph_building_env import Graph, GraphAction, GraphActionCategorical, GraphActionType |
| 9 | +from gflownet.models.graph_transformer import GraphTransformerGFN |
| 10 | + |
| 11 | + |
def relabel(g: Graph, ga: GraphAction):
    """Renumber the nodes of ``g`` contiguously from 0 and remap ``ga`` to match.

    torch_geometric and the EnvironmentContext classes expect node labels to be
    contiguous and start at 0, but GraphBuildingEnv.parent can return parents in
    which a node has been removed, leaving a gap in the labels and leading to a
    faulty encoding of the graph.

    ``ga`` is updated in place (its ``source``/``target`` are rewritten) and the
    same object is returned together with the relabeled graph.
    """
    # Old label -> new contiguous index, following the iteration order of g.nodes.
    old_to_new = {node: idx for idx, node in enumerate(g.nodes)}
    if len(g) == 0 and ga.action == GraphActionType.AddNode:
        # AddNode can add to the empty graph, the source is still 0
        old_to_new[0] = 0
    relabeled = g.relabel_nodes(old_to_new)
    # Rewrite whichever endpoints the action actually uses.
    for endpoint in ("source", "target"):
        old_label = getattr(ga, endpoint)
        if old_label is not None:
            setattr(ga, endpoint, old_to_new[old_label])
    return relabeled, ga
9 | 27 |
|
10 | 28 |
|
11 | 29 | class GraphSampler:
|
@@ -185,3 +203,99 @@ def not_done(lst):
|
185 | 203 | data[i]["traj"].append((graphs[i], GraphAction(GraphActionType.Stop)))
|
186 | 204 | data[i]["is_sink"].append(1)
|
187 | 205 | return data
|
| 206 | + |
def sample_backward_from_graphs(
    self,
    graphs: List[Graph],
    model: Optional[nn.Module],
    cond_info: Tensor,
    dev: torch.device,
    random_action_prob: float = 0.0,
):
    """Sample a model's P_B starting from a list of graphs, or if the model is None, use a uniform
    distribution over legal actions.

    Each graph is walked backward, one sampled backward action at a time, until it is empty. The
    recorded trajectories are then reversed so that they read in the forward direction.

    Parameters
    ----------
    graphs: List[Graph]
        List of Graph endpoints. The list itself is not modified.
    model: Optional[nn.Module]
        Model whose forward() method returns GraphActionCategorical instances (the second returned
        value is used as the backward policy). If None, backward actions are drawn uniformly among
        the legal ones.
    cond_info: Tensor
        Conditional information of each trajectory, shape (n, n_info)
    dev: torch.device
        Device on which data is manipulated
    random_action_prob: float
        Probability of taking a random action (only used if model parameterizes P_B)

    Returns
    -------
    data: List[Dict]
        One dict per input graph with keys:
         - "traj": list of (Graph, GraphAction) pairs in forward order, ending with a Stop action
         - "is_valid": always True
         - "is_sink": list of 0/1 flags aligned with "traj"
         - "bck_a": backward actions in forward order, with a leading Stop action
         - "bck_logprobs": 1-D tensor of log-probabilities of the sampled backward actions
         - "result": the graph the trajectory was sampled from

    Raises
    ------
    NotImplementedError
        If random_action_prob > 0.
    """
    if random_action_prob > 0:
        raise NotImplementedError("Random action not implemented for backward sampling")

    n = len(graphs)
    # Work on a private copy so that the caller's list keeps pointing at the terminal graphs.
    current = list(graphs)
    # A graph that is already empty has no backward step to take; its trajectory is just Stop.
    done = [len(g) == 0 for g in current]
    data = [
        {
            "traj": [(g, GraphAction(GraphActionType.Stop))],
            "is_valid": True,
            "is_sink": [1],
            "bck_a": [GraphAction(GraphActionType.Stop)],
            "bck_logprobs": [0.0],
            "result": g,
        }
        for g in current
    ]

    while not all(done):
        # Indices of the trajectories that still have nodes left to remove.
        active = [i for i in range(n) if not done[i]]
        torch_graphs = [self.ctx.graph_to_Data(current[i]) for i in active]
        if model is not None:
            not_done_mask = torch.tensor(done, device=dev).logical_not()
            _, bck_cat, *_ = model(self.ctx.collate(torch_graphs).to(dev), cond_info[not_done_mask])
        else:
            gbatch = self.ctx.collate(torch_graphs)
            action_types = self.ctx.bck_action_type_order
            masks = [getattr(gbatch, t.mask_name) for t in action_types]
            # Scaling the masks gives every unmasked entry the same large logit, i.e. a uniform
            # choice among them (presumably mask == 1 marks a legal action -- TODO confirm).
            bck_cat = GraphActionCategorical(
                gbatch,
                logits=[m * 1e6 for m in masks],
                keys=[
                    # TODO: This is not very clean, could probably abstract this away somehow
                    GraphTransformerGFN._graph_part_to_key[GraphTransformerGFN._action_type_to_graph_part[t]]
                    for t in action_types
                ],
                masks=masks,
                types=action_types,
            )
        bck_actions = bck_cat.sample()
        graph_bck_actions = [
            self.ctx.aidx_to_GraphAction(g, a, fwd=False) for g, a in zip(torch_graphs, bck_actions)
        ]
        bck_logprobs = bck_cat.log_prob(bck_actions)

        # j indexes the batch of active graphs, i indexes the full list of trajectories.
        for j, i in enumerate(active):
            g = current[i]
            b_a = graph_bck_actions[j]
            gp = self.env.step(g, b_a)
            f_a = self.env.reverse(g, b_a)
            # The parent may have gaps in its node labels; relabel it and the forward action.
            current[i], f_a = relabel(gp, f_a)
            data[i]["traj"].append((current[i], f_a))
            data[i]["bck_a"].append(b_a)
            data[i]["is_sink"].append(0)
            data[i]["bck_logprobs"].append(bck_logprobs[j].item())
            if len(current[i]) == 0:
                done[i] = True

    for i in range(n):
        # See comments in sample_from_model
        data[i]["traj"] = data[i]["traj"][::-1]
        data[i]["bck_a"] = [GraphAction(GraphActionType.Stop)] + data[i]["bck_a"][::-1]
        data[i]["is_sink"] = data[i]["is_sink"][::-1]
        data[i]["bck_logprobs"] = torch.tensor(data[i]["bck_logprobs"][::-1], device=dev).reshape(-1)
        if self.pad_with_terminal_state:
            # Pad with the terminal graph the trajectory ends in, not the emptied working graph.
            data[i]["traj"].append((data[i]["result"], GraphAction(GraphActionType.Stop)))
            data[i]["is_sink"].append(1)
    return data
0 commit comments