Skip to content

Commit cc12abc

Browse files
committed
merge & squash to refresh branch
1 parent ec857a5 commit cc12abc

12 files changed

+1175
-23
lines changed

docs/implementation_notes.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,4 @@ The code contains a specific categorical distribution type for graph actions, `G
3333

3434
Consider, for example, the `AddNode` and `SetEdgeAttr` actions: one applies to nodes and the other to edges. An efficient way to produce logits for these actions is to take the node/edge embeddings and project them (e.g. via an MLP) to a `(n_nodes, n_node_actions)` and an `(n_edges, n_edge_actions)` tensor respectively. We thus obtain a list of tensors representing the logits of the different actions, but the logits of different graphs in the minibatch are mixed together, so one cannot simply apply a `softmax` operator to each tensor.
3535

36-
The `GraphActionCategorical` class handles this and can be used to compute various other things, such as entropy, log probabilities, and so on; it can also be used to sample from the distribution.
36+
The `GraphActionCategorical` class handles this and can be used to compute various other things, such as entropy, log probabilities, and so on; it can also be used to sample from the distribution.

src/gflownet/algo/graph_sampling.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,9 @@ def not_done(lst):
113113
]
114114
if self.sample_temp != 1:
115115
sample_cat = copy.copy(fwd_cat)
116-
sample_cat.logits = [i / self.sample_temp for i in fwd_cat.logits]
116+
sample_cat.logits = [
117+
i * m / self.sample_temp - 1000 * (1 - m) for i, m in zip(fwd_cat.logits, fwd_cat.masks)
118+
]
117119
actions = sample_cat.sample()
118120
else:
119121
actions = fwd_cat.sample()

src/gflownet/algo/trajectory_balance.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch.nn as nn
77
import torch_geometric.data as gd
88
from torch import Tensor
9-
from torch_scatter import scatter, scatter_sum
9+
from torch_scatter import scatter, scatter_sum, scatter_logsumexp
1010

1111
from gflownet.algo.graph_sampling import GraphSampler
1212
from gflownet.config import Config
@@ -309,22 +309,15 @@ def compute_batch_losses(
309309
# Indicate that the `batch` corresponding to each action is the above
310310
ip_log_prob = fwd_cat.log_prob(batch.ip_actions, batch=ip_batch_idces)
311311
# take the logsumexp (because we want to sum probabilities, not log probabilities)
312-
# TODO: numerically stable version:
313-
p = scatter(ip_log_prob.exp(), ip_batch_idces, dim=0, dim_size=batch_idx.shape[0], reduce="sum")
314-
# As a (reasonable) band-aid, ignore p < 1e-30, this will prevent underflows due to
315-
# scatter(small number) = 0 on CUDA
316-
log_p_F = p.clamp(1e-30).log()
312+
log_p_F = scatter_logsumexp(ip_log_prob, ip_batch_idces, dim=0, dim_size=batch_idx.shape[0])
317313

318314
if self.cfg.do_parameterize_p_b:
319315
# Now we repeat this but for the backward policy
320316
bck_ip_batch_idces = torch.arange(batch.bck_ip_lens.shape[0], device=dev).repeat_interleave(
321317
batch.bck_ip_lens
322318
)
323319
bck_ip_log_prob = bck_cat.log_prob(batch.bck_ip_actions, batch=bck_ip_batch_idces)
324-
bck_p = scatter(
325-
bck_ip_log_prob.exp(), bck_ip_batch_idces, dim=0, dim_size=batch_idx.shape[0], reduce="sum"
326-
)
327-
log_p_B = bck_p.clamp(1e-30).log()
320+
log_p_B = scatter_logsumexp(bck_ip_log_prob, bck_ip_batch_idces, dim=0, dim_size=batch_idx.shape[0])
328321
else:
329322
# Else just naively take the logprob of the actions we took
330323
log_p_F = fwd_cat.log_prob(batch.actions)
@@ -496,7 +489,7 @@ def subtb_loss_fast(self, P_F, P_B, F, R, traj_lengths):
496489
cumul_lens = torch.cumsum(torch.cat([torch.zeros(1, device=dev), traj_lengths]), 0).long()
497490
total_loss = torch.zeros(num_trajs, device=dev)
498491
ar = torch.arange(max_len, device=dev)
499-
car = torch.cumsum(ar, 0)
492+
car = torch.cumsum(ar, 0) if self.length_normalize_losses else torch.ones_like(ar)
500493
F_and_R = torch.cat([F, R])
501494
R_start = F.shape[0]
502495
for ep in range(traj_lengths.shape[0]):

src/gflownet/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ class Config:
8181
"""
8282

8383
log_dir: str = MISSING
84+
log_sampled_data: bool = True
8485
device: str = "cuda"
8586
seed: int = 0
8687
validate_every: int = 1000

src/gflownet/data/sampling_iterator.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import sqlite3
3+
import traceback
34
from collections.abc import Iterable
45
from copy import deepcopy
56
from typing import Callable, List
@@ -11,6 +12,7 @@
1112
from rdkit import Chem, RDLogger
1213
from torch.utils.data import Dataset, IterableDataset
1314

15+
from gflownet.config import Config
1416
from gflownet.data.replay_buffer import ReplayBuffer
1517
from gflownet.envs.graph_building_env import GraphActionCategorical
1618

@@ -112,9 +114,14 @@ def __init__(
112114
# don't want to initialize per-worker things just yet, such as the log the worker writes
113115
# to. This must be done in __iter__, which is called by the DataLoader once this instance
114116
# has been copied into a new python process.
115-
self.log_dir = log_dir
117+
import warnings
118+
119+
warnings.warn("Fix dependency on cfg.log_sampled_data")
120+
self.log_dir = log_dir # if cfg.log_sampled_data else None
116121
self.log = SQLiteLog()
117122
self.log_hooks: List[Callable] = []
123+
# TODO: make this a proper flag / make a separate class for logging sampled molecules to a SQLite db
124+
self.log_molecule_smis = not hasattr(self.ctx, "not_a_molecule_env") and self.log_dir is not None
118125

119126
def add_log_hook(self, hook: Callable):
120127
self.log_hooks.append(hook)
@@ -158,6 +165,14 @@ def __len__(self):
158165
return len(self.data)
159166

160167
def __iter__(self):
168+
try:
169+
for x in self.iterator():
170+
yield x
171+
except Exception as e:
172+
traceback.print_exc()
173+
raise e
174+
175+
def iterator(self):
161176
worker_info = torch.utils.data.get_worker_info()
162177
self._wid = worker_info.id if worker_info is not None else 0
163178
# Now that we know we are in a worker instance, we can initialize per-worker things
@@ -189,9 +204,7 @@ def __iter__(self):
189204
else: # If we're not sampling the conditionals, then the idcs refer to listed preferences
190205
num_online = num_offline
191206
num_offline = 0
192-
cond_info = self.task.encode_conditional_information(
193-
steer_info=torch.stack([self.data[i] for i in idcs])
194-
)
207+
cond_info = self.task.encode_conditional_information(torch.stack([self.data[i] for i in idcs]))
195208
trajs, flat_rewards = [], []
196209

197210
# Sample some on-policy data
@@ -250,14 +263,16 @@ def __iter__(self):
250263
# note: we convert back into natural rewards for logging purposes
251264
# (allows to take averages and plot in objective space)
252265
# TODO: implement that per-task (in case they don't apply the same beta and log transformations)
253-
rewards = torch.exp(log_rewards / cond_info["beta"])
266+
rewards = torch.exp(log_rewards / (cond_info["beta"] if "beta" in cond_info else 1.0))
254267
if num_online > 0 and self.log_dir is not None:
255268
self.log_generated(
256269
deepcopy(trajs[num_offline:]),
257270
deepcopy(rewards[num_offline:]),
258271
deepcopy(flat_rewards[num_offline:]),
259272
{k: v[num_offline:] for k, v in deepcopy(cond_info).items()},
260273
)
274+
275+
extra_info = {}
261276
if num_online > 0:
262277
for hook in self.log_hooks:
263278
extra_info.update(

src/gflownet/envs/basic_graph_ctx.py

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
from typing import Dict, List, Tuple
2+
3+
import networkx as nx
4+
import torch
5+
import torch_geometric.data as gd
6+
from networkx.algorithms.isomorphism import is_isomorphic as nx_is_isomorphic
7+
8+
from gflownet.envs.graph_building_env import (
9+
Graph,
10+
GraphAction,
11+
GraphActionType,
12+
GraphBuildingEnvContext,
13+
graph_without_edge,
14+
)
15+
from gflownet.utils.graphs import random_walk_probs
16+
17+
18+
def hashg(g):
    """Return the Weisfeiler-Lehman graph hash of `g`, using the node attribute "v" as node label."""
    wl_hash = nx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash
    return wl_hash(g, node_attr="v")
20+
21+
22+
def is_isomorphic(u, v):
    """Return whether graphs `u` and `v` are isomorphic, requiring equal node and edge attribute dicts."""

    def same_attrs(a, b):
        # Matched nodes (resp. edges) must carry exactly the same attributes
        return a == b

    return nx_is_isomorphic(u, v, node_match=same_attrs, edge_match=same_attrs)
24+
25+
26+
class BasicGraphContext(GraphBuildingEnvContext):
    """
    A basic graph generation context.

    This simple environment context is designed to be used to test implementations. It only allows for AddNode and
    AddEdge actions, and is meant to be used within the BasicGraphTask to generate graphs of up to 7 nodes with
    only two possible node attributes, making the state space a total of ~70k states (which is nicely enumerable
    and allows us to compute p_theta(x) exactly for all x in the state space).
    """

    def __init__(self, max_nodes=7, num_cond_dim=0, graph_data=None, output_gid=False):
        """Set up the action spaces, feature dimensions and (optionally) the graph -> index lookup.

        Parameters
        ----------
        max_nodes: int
            Number of nodes at which the AddNode action gets masked out.
        num_cond_dim: int
            Stored as `self.num_cond_dim`.
        graph_data: optional sequence of graphs
            If given, every graph is hashed with `hashg` so that `get_graph_idx` can map a graph back to
            its position in this sequence.
        output_gid: bool
            If True, `graph_to_Data` stores the index of the graph (as found by `get_graph_idx`) in the
            `gid` field of the produced Data; otherwise `gid` is 0.
        """
        self.max_nodes = max_nodes
        self.output_gid = output_gid

        self.node_attr_values = {
            "v": [0, 1],  # Imagine these as colors
        }
        self._num_rw_feat = 8

        self.num_new_node_values = len(self.node_attr_values["v"])
        self.num_node_attr_logits = None
        # one-hot of "v" + 1 indicator for the empty graph + the random walk features
        self.num_node_dim = self.num_new_node_values + 1 + self._num_rw_feat
        self.num_node_attrs = 1
        self.num_edge_attr_logits = None
        self.num_edge_attrs = 0
        self.num_cond_dim = num_cond_dim
        self.num_edge_dim = 1
        self.edges_are_duplicated = True
        self.edges_are_unordered = True

        # Order in which models have to output logits
        self.action_type_order = [
            GraphActionType.Stop,
            GraphActionType.AddNode,
            GraphActionType.AddEdge,
        ]
        self.bck_action_type_order = [
            GraphActionType.RemoveNode,
            GraphActionType.RemoveEdge,
        ]
        self.device = torch.device("cpu")
        self.graph_data = graph_data
        # Maps a graph hash to the list of (graph, index in graph_data) pairs that share this hash
        self.hash_to_graphs: Dict[str, List[Tuple[Graph, int]]] = {}
        if graph_data is not None:
            for i, g in enumerate(graph_data):
                self.hash_to_graphs.setdefault(hashg(g), []).append((g, i))

    def get_graph_idx(self, g, default=None):
        """Return the index in `graph_data` of the graph matching `g`.

        Parameters
        ----------
        g: Graph
            The graph to look up.
        default: optional
            Value returned when no matching graph is found. `None` means "no default".

        Raises
        ------
        ValueError
            If no matching graph is found and `default` is None.
        """
        bucket = self.hash_to_graphs.get(hashg(g))
        if bucket is not None:
            if len(bucket) == 1:
                # A single graph has this hash, it is returned without an explicit isomorphism test
                return bucket[0][1]
            for candidate, idx in bucket:
                if is_isomorphic(candidate, g):
                    return idx
        if default is not None:
            return default
        raise ValueError(g)

    def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: bool = True) -> GraphAction:
        """Translate an action index (e.g. from a GraphActionCategorical) to a GraphAction

        Parameters
        ----------
        g: gd.Data
            The graph the action applies to; its `non_edge_index` and `edge_index` are used to recover the
            endpoints of AddEdge and RemoveEdge actions.
        action_idx: Tuple[int, int, int]
            The (action type, row, column) index of the action.
        fwd: bool
            If True the type index refers to `action_type_order`, otherwise to `bck_action_type_order`.

        Raises
        ------
        ValueError
            If the selected action type is not handled by this context.
        """
        act_type, act_row, act_col = [int(i) for i in action_idx]
        type_order = self.action_type_order if fwd else self.bck_action_type_order
        t = type_order[act_type]

        if t is GraphActionType.Stop:
            return GraphAction(t)
        if t is GraphActionType.AddNode:
            return GraphAction(t, source=act_row, value=self.node_attr_values["v"][act_col])
        if t is GraphActionType.AddEdge:
            a, b = g.non_edge_index[:, act_row]
            return GraphAction(t, source=a.item(), target=b.item())
        if t is GraphActionType.RemoveNode:
            return GraphAction(t, source=act_row)
        if t is GraphActionType.RemoveEdge:
            # edges are duplicated in edge_index, but edge logits are not
            a, b = g.edge_index[:, act_row * 2]
            return GraphAction(t, source=a.item(), target=b.item())
        raise ValueError(f"Unsupported action type: {t}")

    def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int, int]:
        """Translate a GraphAction to an index tuple

        Parameters
        ----------
        g: gd.Data
            The graph the action applies to; its `non_edge_index` and `edge_index` are searched for the
            endpoints of AddEdge and RemoveEdge actions.
        action: GraphAction
            The action to translate.

        Returns
        -------
        The (action type, row, column) index, where the type index refers to `action_type_order` for
        Stop/AddNode/AddEdge and to `bck_action_type_order` for RemoveNode/RemoveEdge.

        Raises
        ------
        ValueError
            If the action type is not handled by this context.
        """
        t = action.action
        if t is GraphActionType.Stop:
            row = col = 0
            type_order = self.action_type_order
        elif t is GraphActionType.AddNode:
            row = action.source
            col = self.node_attr_values["v"].index(action.value)
            type_order = self.action_type_order
        elif t is GraphActionType.AddEdge:
            # Here we have to retrieve the index in non_edge_index of an edge (s,t)
            # that's also possibly in the reverse order (t,s).
            # That's definitely not too efficient, can we do better?
            non_edges = g.non_edge_index.T
            row = (
                (non_edges == torch.tensor([(action.source, action.target)])).prod(1)
                + (non_edges == torch.tensor([(action.target, action.source)])).prod(1)
            ).argmax()
            col = 0
            type_order = self.action_type_order
        elif t is GraphActionType.RemoveNode:
            row = action.source
            col = 0
            type_order = self.bck_action_type_order
        elif t is GraphActionType.RemoveEdge:
            row = ((g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1)).argmax()
            row = int(row) // 2  # edges are duplicated, but edge logits are not
            col = 0
            type_order = self.bck_action_type_order
        else:
            raise ValueError(f"Unsupported action type: {t}")
        return (type_order.index(t), int(row), int(col))

    def graph_to_Data(self, g: Graph) -> gd.Data:
        """Convert a networkx Graph to a torch geometric Data instance

        The returned Data holds the node features `x`, the duplicated `edge_index` with all-zero
        `edge_attr`, the `non_edge_index` built from the complement of `g`, one mask per action type and
        the graph index `gid`. It is passed through `_preprocess` before being returned.
        """
        # There is always at least one row in x, even for the empty graph
        x = torch.zeros((max(1, len(g.nodes)), self.num_node_dim - self._num_rw_feat))
        # The last column of the first row flags the empty graph
        x[0, -1] = len(g.nodes) == 0
        remove_node_mask = torch.zeros((x.shape[0], 1)) + (1 if len(g) == 0 else 0)
        for i, n in enumerate(g.nodes):
            ad = g.nodes[n]
            x[i, self.node_attr_values["v"].index(ad["v"])] = 1
            # Only nodes with at most one neighbor are marked as removable
            if g.degree(n) <= 1:
                remove_node_mask[i] = 1

        # One row per undirected edge (edge_index below has two columns per edge)
        remove_edge_mask = torch.zeros((len(g.edges), 1))
        for i, (u, v) in enumerate(g.edges):
            # An edge is marked as removable if both its endpoints keep at least one other edge and the
            # graph stays connected without it
            if g.degree(u) > 1 and g.degree(v) > 1:
                if nx.algorithms.is_connected(graph_without_edge(g, (u, v))):
                    remove_edge_mask[i] = 1
        edge_attr = torch.zeros((len(g.edges) * 2, self.num_edge_dim))
        # Each edge (i, j) is immediately followed by its reverse (j, i)
        edge_index = (
            torch.tensor([e for i, j in g.edges for e in [(i, j), (j, i)]], dtype=torch.long).reshape((-1, 2)).T
        )
        gc = nx.complement(g)
        non_edge_index = torch.tensor([i for i in gc.edges], dtype=torch.long).reshape((-1, 2)).T
        gid = self.get_graph_idx(g) if self.output_gid else 0

        return self._preprocess(
            gd.Data(
                x,
                edge_index,
                edge_attr,
                non_edge_index=non_edge_index,
                stop_mask=torch.ones((1, 1)),
                # AddNode is masked out as soon as the graph has max_nodes nodes
                add_node_mask=torch.ones((x.shape[0], self.num_new_node_values)) * (len(g) < self.max_nodes),
                add_edge_mask=torch.ones((non_edge_index.shape[1], 1)),
                remove_node_mask=remove_node_mask,
                remove_edge_mask=remove_edge_mask,
                gid=gid,
            )
        )

    def _preprocess(self, g: gd.Data) -> gd.Data:
        """Append the output of `random_walk_probs` to the node features of `g` (if enabled) and return `g`."""
        if self._num_rw_feat > 0:
            g.x = torch.cat([g.x, random_walk_probs(g, self._num_rw_feat, skip_odd=True)], 1)
        return g

    def collate(self, graphs: List[gd.Data]):
        """Batch Data instances"""
        return gd.Batch.from_data_list(graphs, follow_batch=["edge_index", "non_edge_index"])

    def mol_to_graph(self, obj: Graph) -> Graph:
        """Return `obj` unchanged."""
        return obj  # This is already a graph

    def graph_to_mol(self, g: Graph) -> Graph:
        """Return `g` unchanged."""
        # idem
        return g

    def is_sane(self, g: Graph) -> bool:
        """Always return True."""
        return True

    def get_object_description(self, g: Graph, is_valid: bool) -> str:
        """Return, as a string, the index of `g` given by `get_graph_idx`, or "-1" if it is not found."""
        return str(self.get_graph_idx(g, -1))

src/gflownet/envs/graph_building_env.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def generate_forward_trajectory(g: Graph, max_nodes: int = None) -> List[Tuple[G
331331
# TODO: should this be a method of GraphBuildingEnv? handle set_node_attr flags and so on?
332332
gn = Graph()
333333
# Choose an arbitrary starting point, add to the stack
334-
stack: List[Tuple[int, ...]] = [(np.random.randint(0, len(g.nodes)),)]
334+
stack: List[Tuple[int, ...]] = [(np.random.randint(0, len(g.nodes)),)] if len(g.nodes) > 0 else []
335335
traj = []
336336
# This map keeps track of node labels in gn, since we have to start from 0
337337
relabeling_map: Dict[int, int] = {}
@@ -777,6 +777,7 @@ class GraphBuildingEnvContext:
777777
"""A context class defines what the graphs are, how they map to and from data"""
778778

779779
device: torch.device
780+
num_cond_dim: int = 0
780781

781782
def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: bool = True) -> GraphAction:
782783
"""Translate an action index (e.g. from a GraphActionCategorical) to a GraphAction

0 commit comments

Comments
 (0)