Skip to content

Commit e6b7906

Browse files
Add backward sampling (#110)
* recursive decompose routine for fragment env * backward sampling + some extra stuff in seh * merge some stuff from feat_two_models * fixed bugs & tested * tox + Mohit authorship Co-authored-by: Mohit Pandey <pandey.mohitk@gmail.com> * ruff & mypy fixes * minor elif fix * reset to default * add pyg pins --------- Co-authored-by: Mohit Pandey <pandey.mohitk@gmail.com>
1 parent 2f9d43b commit e6b7906

File tree

12 files changed

+427
-74
lines changed

12 files changed

+427
-74
lines changed

docs/examples/grid_cond_gfn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def sample_many(self, mbsize):
253253
return log_ratio
254254

255255
def learn_from(self, it, batch):
256-
if type(batch) is list:
256+
if isinstance(batch, list):
257257
log_ratio = torch.stack(batch, 0)
258258
else:
259259
log_ratio = batch

pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ requires-python = ">=3.8,<3.10"
5757
dynamic = ["version"]
5858
dependencies = [
5959
"torch==1.13.1",
60-
"torch-geometric",
61-
"torch-scatter",
62-
"torch-sparse",
63-
"torch-cluster",
60+
"torch-geometric==2.3.1", # Pinning until we adapt the code to newer versions
61+
"torch-scatter==2.1.1",
62+
"torch-sparse==0.6.17",
63+
"torch-cluster==1.6.1",
6464
"rdkit",
6565
"tables",
6666
"scipy",

src/gflownet/algo/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class TBConfig:
4545
variant: TBVariant = TBVariant.TB
4646
do_correct_idempotent: bool = False
4747
do_parameterize_p_b: bool = False
48+
do_sample_p_b: bool = False
4849
do_length_normalize: bool = False
4950
subtb_max_len: int = 128
5051
Z_learning_rate: float = 1e-4

src/gflownet/algo/graph_sampling.py

Lines changed: 116 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,29 @@
11
import copy
2-
from typing import List
2+
from typing import List, Optional
33

44
import torch
55
import torch.nn as nn
66
from torch import Tensor
77

8-
from gflownet.envs.graph_building_env import GraphAction, GraphActionType
8+
from gflownet.envs.graph_building_env import Graph, GraphAction, GraphActionCategorical, GraphActionType
9+
from gflownet.models.graph_transformer import GraphTransformerGFN
10+
11+
12+
def relabel(g: Graph, ga: GraphAction):
    """Relabel the nodes of g to 0-N, and remap the graph action ga applied to g accordingly.

    This is necessary because torch_geometric and EnvironmentContext classes expect nodes to be
    labeled 0-N, whereas GraphBuildingEnv.parent can return parents with e.g. a removed node that
    creates a gap in 0-N, leading to a faulty encoding of the graph.

    Parameters
    ----------
    g: Graph
        Graph whose node labels may contain gaps.
    ga: GraphAction
        Action expressed with the node labels of `g`. It is NOT modified; a shallow copy carrying
        the remapped `source`/`target` is returned instead.

    Returns
    -------
    g, ga: Tuple[Graph, GraphAction]
        The result of `g.relabel_nodes(...)` and the remapped copy of the action.
    """
    # Map each existing node label to its position in g.nodes' iteration order.
    rmap = {node: idx for idx, node in enumerate(g.nodes)}
    if not len(g) and ga.action == GraphActionType.AddNode:
        rmap[0] = 0  # AddNode can add to the empty graph, the source is still 0
    g = g.relabel_nodes(rmap)
    # Work on a copy so that callers holding a reference to the original action do not see its
    # source/target silently rewritten.
    ga = copy.copy(ga)
    if ga.source is not None:
        ga.source = rmap[ga.source]
    if ga.target is not None:
        ga.target = rmap[ga.target]
    return g, ga
927

1028

1129
class GraphSampler:
@@ -185,3 +203,99 @@ def not_done(lst):
185203
data[i]["traj"].append((graphs[i], GraphAction(GraphActionType.Stop)))
186204
data[i]["is_sink"].append(1)
187205
return data
206+
207+
def sample_backward_from_graphs(
    self,
    graphs: List[Graph],
    model: Optional[nn.Module],
    cond_info: Tensor,
    dev: torch.device,
    random_action_prob: float = 0.0,
):
    """Sample a model's P_B starting from a list of graphs, or if the model is None, use a uniform distribution
    over legal actions.

    The input list `graphs` is not modified; the sampler works on an internal copy of the list.

    Parameters
    ----------
    graphs: List[Graph]
        List of Graph endpoints
    model: Optional[nn.Module]
        Model whose forward() method returns a tuple whose second element is the backward
        GraphActionCategorical. If None, a categorical is built directly from the backward action masks.
    cond_info: Tensor
        Conditional information of each trajectory, shape (n, n_info)
    dev: torch.device
        Device on which data is manipulated
    random_action_prob: float
        Probability of taking a random action. Not implemented for backward sampling: any value > 0 raises.

    Returns
    -------
    data: List[Dict]
        One dict per input graph with keys:
         - traj: List[Tuple[Graph, GraphAction]], ordered from the empty graph to the endpoint
         - is_valid: bool, always True here
         - is_sink: List[int], 1 on the terminal (Stop) entries, 0 elsewhere
         - bck_a: List[GraphAction], the sampled backward actions, with a leading Stop
         - bck_logprobs: Tensor, log-probability of each sampled backward action
         - result: Graph, the endpoint the trajectory was sampled from

    Raises
    ------
    NotImplementedError
        If `random_action_prob > 0`.
    """
    if random_action_prob > 0:
        raise NotImplementedError("Random action not implemented for backward sampling")

    # Keep the endpoints around (needed for terminal-state padding) and only mutate a private
    # working copy, so that the caller's list is left intact.
    starting_graphs = list(graphs)
    current = list(graphs)
    n = len(current)
    # A trajectory is finished once its graph is empty; endpoints that are already empty have
    # no backward action to take and must not be sent to the sampler.
    done = [len(g) == 0 for g in current]
    data = [
        {
            "traj": [(starting_graphs[i], GraphAction(GraphActionType.Stop))],
            "is_valid": True,
            "is_sink": [1],
            "bck_a": [GraphAction(GraphActionType.Stop)],
            "bck_logprobs": [0.0],
            "result": starting_graphs[i],
        }
        for i in range(n)
    ]

    while not all(done):
        # Indices of the trajectories that still have to be unrolled, in batch order.
        active = [i for i in range(n) if not done[i]]
        torch_graphs = [self.ctx.graph_to_Data(current[i]) for i in active]
        not_done_mask = torch.tensor(done, device=dev).logical_not()
        if model is not None:
            _, bck_cat, *_ = model(self.ctx.collate(torch_graphs).to(dev), cond_info[not_done_mask])
        else:
            gbatch = self.ctx.collate(torch_graphs)
            action_types = self.ctx.bck_action_type_order
            masks = [getattr(gbatch, t.mask_name) for t in action_types]
            # Large equal logits on every unmasked action; masked actions get a logit of 0.
            bck_cat = GraphActionCategorical(
                gbatch,
                logits=[m * 1e6 for m in masks],
                keys=[
                    # TODO: This is not very clean, could probably abstract this away somehow
                    GraphTransformerGFN._graph_part_to_key[GraphTransformerGFN._action_type_to_graph_part[t]]
                    for t in action_types
                ],
                masks=masks,
                types=action_types,
            )
        bck_actions = bck_cat.sample()
        graph_bck_actions = [
            self.ctx.aidx_to_GraphAction(g, a, fwd=False) for g, a in zip(torch_graphs, bck_actions)
        ]
        bck_logprobs = bck_cat.log_prob(bck_actions)

        # j indexes the batch of active graphs, i the original trajectory.
        for j, i in enumerate(active):
            g = current[i]
            b_a = graph_bck_actions[j]
            gp = self.env.step(g, b_a)
            f_a = self.env.reverse(g, b_a)
            # The parent may have gaps in its node labels; relabel it (and the forward action).
            current[i], f_a = relabel(gp, f_a)
            data[i]["traj"].append((current[i], f_a))
            data[i]["bck_a"].append(b_a)
            data[i]["is_sink"].append(0)
            data[i]["bck_logprobs"].append(bck_logprobs[j].item())
            if len(current[i]) == 0:
                done[i] = True

    for i in range(n):
        # See comments in sample_from_model
        # Everything was collected endpoint-first; flip it into forward order.
        data[i]["traj"] = data[i]["traj"][::-1]
        data[i]["bck_a"] = [GraphAction(GraphActionType.Stop)] + data[i]["bck_a"][::-1]
        data[i]["is_sink"] = data[i]["is_sink"][::-1]
        data[i]["bck_logprobs"] = torch.tensor(data[i]["bck_logprobs"][::-1], device=dev).reshape(-1)
        if self.pad_with_terminal_state:
            # Pad with the terminal graph (the endpoint), not with the emptied working graph.
            data[i]["traj"].append((starting_graphs[i], GraphAction(GraphActionType.Stop)))
            data[i]["is_sink"].append(1)
    return data

src/gflownet/algo/trajectory_balance.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Tuple
1+
from typing import Optional, Tuple
22

33
import networkx as nx
44
import numpy as np
@@ -148,8 +148,8 @@ def create_training_data_from_own_samples(
148148
----------
149149
model: TrajectoryBalanceModel
150150
The model being sampled
151-
graphs: List[Graph]
152-
List of N Graph endpoints
151+
n: int
152+
Number of trajectories to sample
153153
cond_info: torch.tensor
154154
Conditional information, shape (N, n_info)
155155
random_action_prob: float
@@ -174,19 +174,38 @@ def create_training_data_from_own_samples(
174174
data[i]["logZ"] = logZ_pred[i].item()
175175
return data
176176

177-
def create_training_data_from_graphs(self, graphs):
177+
def create_training_data_from_graphs(
178+
self,
179+
graphs,
180+
model: Optional[TrajectoryBalanceModel] = None,
181+
cond_info: Optional[Tensor] = None,
182+
random_action_prob: Optional[float] = None,
183+
):
178184
"""Generate trajectories from known endpoints
179185
180186
Parameters
181187
----------
182188
graphs: List[Graph]
183189
List of Graph endpoints
190+
model: TrajectoryBalanceModel
191+
The model being sampled
192+
cond_info: torch.tensor
193+
Conditional information, shape (N, n_info)
194+
random_action_prob: float
195+
Probability of taking a random action
184196
185197
Returns
186198
-------
187199
trajs: List[Dict{'traj': List[tuple[Graph, GraphAction]]}]
188200
A list of trajectories.
189201
"""
202+
if self.cfg.do_sample_p_b:
203+
assert model is not None and cond_info is not None and random_action_prob is not None
204+
dev = self.ctx.device
205+
cond_info = cond_info.to(dev)
206+
return self.graph_sampler.sample_backward_from_graphs(
207+
graphs, model if self.cfg.do_parameterize_p_b else None, cond_info, dev, random_action_prob
208+
)
190209
trajs = [{"traj": generate_forward_trajectory(i)} for i in graphs]
191210
for traj in trajs:
192211
n_back = [

src/gflownet/data/sampling_iterator.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,12 +177,13 @@ def __iter__(self):
177177
)
178178

179179
# Sample some dataset data
180-
mols, flat_rewards = map(list, zip(*[self.data[i] for i in idcs])) if len(idcs) else ([], [])
180+
graphs, flat_rewards = map(list, zip(*[self.data[i] for i in idcs])) if len(idcs) else ([], [])
181181
flat_rewards = (
182182
list(self.task.flat_reward_transform(torch.stack(flat_rewards))) if len(flat_rewards) else []
183183
)
184-
graphs = [self.ctx.mol_to_graph(m) for m in mols]
185-
trajs = self.algo.create_training_data_from_graphs(graphs)
184+
trajs = self.algo.create_training_data_from_graphs(
185+
graphs, self.model, cond_info["encoding"][:num_offline], 0
186+
)
186187

187188
else: # If we're not sampling the conditionals, then the idcs refer to listed preferences
188189
num_online = num_offline
@@ -411,7 +412,9 @@ def _make_results_table(self, types, names):
411412
cur.close()
412413

413414
def insert_many(self, rows, column_names):
414-
assert all([type(x) is str or not isinstance(x, Iterable) for x in rows[0]]), "rows must only contain scalars"
415+
assert all(
416+
[isinstance(x, str) or not isinstance(x, Iterable) for x in rows[0]]
417+
), "rows must only contain scalars"
415418
if not self._has_results_table:
416419
self._make_results_table([type(i) for i in rows[0]], column_names)
417420
cur = self.db.cursor()

0 commit comments

Comments
 (0)