Commit 9e489ab

Merge branch 'trunk' into add_small_graph_task
2 parents 8969470 + ec857a5 commit 9e489ab

23 files changed: +516 -249 lines

VERSION

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 MAJOR="0"
-MINOR="0"
+MINOR="1"

docs/implementation_notes.md

Lines changed: 18 additions & 0 deletions
@@ -16,3 +16,21 @@ We separate experiment concerns in four categories:
 - The Trainer class is responsible for instanciating everything, and running the training & testing loop

 Typically one would setup a new experiment by creating a class that inherits from `GFNTask` and a class that inherits from `GFNTrainer`. To implement a new MDP, one would create a class that inherits from `GraphBuildingEnvContext`.
+
+
+## Graphs
+
+This library is built around the idea of generating graphs. We use the `networkx` library to represent graphs, and we use the `torch_geometric` library to represent graphs as tensors for the models. There is a fair amount of code that is dedicated to converting between the two representations.
+
+Some notes:
+- graphs are (for now) assumed to be _undirected_. This is encoded for `torch_geometric` by duplicating the edges (contiguously) in both directions. Models still only produce one logit(-row) per edge, so the policy is still assumed to operate on undirected graphs.
+- When converting from `GraphAction`s (nx) to so-called `aidx`s, the `aidx`s are encoding-bound, i.e. they point to specific rows and columns in the torch encoding.
+
+
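
For concreteness, a minimal sketch of the edge-duplication scheme from the first note above (variable names are illustrative, not the library's actual converter code):

    import networkx as nx
    import torch

    g = nx.Graph()
    g.add_edge(0, 1)
    g.add_edge(1, 2)

    # Each undirected nx edge i occupies columns 2*i and 2*i+1 of edge_index,
    # i.e. both directions are stored contiguously.
    pairs = [d for u, v in g.edges for d in ((u, v), (v, u))]
    edge_index = torch.tensor(pairs, dtype=torch.long).T  # shape (2, 2 * n_edges)
    print(edge_index)
    # tensor([[0, 1, 1, 2],
    #         [1, 0, 2, 1]])
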
+### Graph policies & graph action categoricals
+
+The code contains a specific categorical distribution type for graph actions, `GraphActionCategorical`. This class contains logic to sample from concatenated sets of logits across a minibatch.
+
+Consider for example the `AddNode` and `SetEdgeAttr` actions: one applies to nodes and one to edges. An efficient way to produce logits for these actions is to take the node/edge embeddings and project them (e.g. via an MLP) to a `(n_nodes, n_node_actions)` tensor and an `(n_edges, n_edge_actions)` tensor respectively. We thus obtain a list of tensors representing the logits of different actions, but the logits of different graphs in the minibatch are concatenated together, so one cannot simply apply a `softmax` operator over each tensor.
+
+The `GraphActionCategorical` class handles this; it can be used to sample from the distribution, and to compute various other quantities such as the entropy and log-probabilities.
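
To see why a single softmax would be wrong, here is a toy sketch of the per-graph normalization (illustrative only; the real GraphActionCategorical also handles multiple action types, masking, and sampling, and uses a numerically stable logsumexp rather than a bare exp):

    import torch

    # Two graphs with 2 and 3 nodes; per-node logits are concatenated across the minibatch,
    # so rows 0-1 belong to graph 0 and rows 2-4 to graph 1.
    node_logits = torch.randn(5, 4)            # 4 hypothetical per-node actions
    batch_idx = torch.tensor([0, 0, 1, 1, 1])  # row -> graph membership

    # One partition function per graph, instead of one softmax over all rows:
    Z = torch.zeros(2)
    Z.index_add_(0, batch_idx, node_logits.exp().sum(dim=1))
    log_probs = node_logits - Z.log()[batch_idx, None]

    # Each graph's action probabilities sum to 1 over its own rows and columns:
    check = torch.zeros(2)
    check.index_add_(0, batch_idx, log_probs.exp().sum(dim=1))
    print(check)  # tensor([1., 1.])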

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ dependencies = [
     "botorch",
     "pyro-ppl",
     "gpytorch",
-    "omegaconf",
+    "omegaconf>=2.3",
 ]

 [project.optional-dependencies]

src/gflownet/algo/config.py

Lines changed: 3 additions & 0 deletions
@@ -91,6 +91,8 @@ class AlgoConfig:
     offline_ratio: float
         The ratio of samples drawn from `self.training_data` during training. The rest is drawn from
         `self.sampling_model`
+    valid_offline_ratio: float
+        Idem but for validation, and `self.test_data`.
     train_random_action_prob : float
         The probability of taking a random action during training
     valid_random_action_prob : float
@@ -108,6 +110,7 @@ class AlgoConfig:
     max_edges: int = 128
     illegal_action_logreward: float = -100
     offline_ratio: float = 0.5
+    valid_offline_ratio: float = 1
     train_random_action_prob: float = 0.0
     valid_random_action_prob: float = 0.0
     valid_sample_cond_info: bool = True
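
As a sketch of what these ratios mean downstream (this mirrors the ceil/floor split visible in the SamplingIterator diff below; the numbers are illustrative):

    import numpy as np

    batch_size, offline_ratio = 64, 0.5
    offline = int(np.ceil(batch_size * offline_ratio))        # 32 trajectories from training_data
    online = int(np.floor(batch_size * (1 - offline_ratio)))  # 32 sampled from the model
    # With valid_offline_ratio = 1 (the new default), validation batches come entirely from test_data.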

src/gflownet/algo/trajectory_balance.py

Lines changed: 4 additions & 1 deletion
@@ -31,7 +31,10 @@ def logZ(self, cond_info: Tensor) -> Tensor:


 class TrajectoryBalance(GFNAlgorithm):
-    """ """
+    """TB implementation, see
+    "Trajectory Balance: Improved Credit Assignment in GFlowNets", Nikolay Malkin, Moksh Jain,
+    Emmanuel Bengio, Chen Sun, Yoshua Bengio
+    https://arxiv.org/abs/2201.13259"""

     def __init__(
         self,

src/gflownet/config.py

Lines changed: 6 additions & 0 deletions
@@ -52,12 +52,16 @@ class Config:
     ----------
     log_dir : str
         The directory where to store logs, checkpoints, and samples.
+    device : str
+        The device to use for training (either "cpu" or "cuda[:<device_id>]")
     seed : int
         The random seed
     validate_every : int
         The number of training steps after which to validate the model
     checkpoint_every : Optional[int]
         The number of training steps after which to checkpoint the model
+    print_every : int
+        The number of training steps after which to print the training loss
     start_at_step : int
         The training step to start at (default: 0)
     num_final_gen_steps : Optional[int]
@@ -78,9 +82,11 @@ class Config:

     log_dir: str = MISSING
     log_sampled_data: bool = True
+    device: str = "cuda"
     seed: int = 0
     validate_every: int = 1000
     checkpoint_every: Optional[int] = None
+    print_every: int = 100
     start_at_step: int = 0
     num_final_gen_steps: Optional[int] = None
     num_training_steps: int = 10_000
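
A minimal sketch of overriding the new fields (assuming the nested AlgoConfig is reachable as cfg.algo, consistent with the cfg.algo.* references in the SamplingIterator diff below):

    from gflownet.config import Config

    cfg = Config(log_dir="./logs/debug_run")
    cfg.device = "cpu"                  # or "cuda[:<device_id>]"
    cfg.print_every = 10                # report the training loss every 10 steps
    cfg.algo.valid_offline_ratio = 1.0  # validation batches drawn entirely from test_data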

src/gflownet/data/qm9.py

Lines changed: 21 additions & 1 deletion
@@ -3,6 +3,7 @@
 import numpy as np
 import pandas as pd
 import rdkit.Chem as Chem
+import torch
 from torch.utils.data import Dataset


@@ -39,4 +40,23 @@ def __len__(self):
         return len(self.idcs)

     def __getitem__(self, idx):
-        return (Chem.MolFromSmiles(self.df["SMILES"][self.idcs[idx]]), self.df[self.target][self.idcs[idx]])
+        return (
+            Chem.MolFromSmiles(self.df["SMILES"][self.idcs[idx]]),
+            torch.tensor([self.df[self.target][self.idcs[idx]]]).float(),
+        )
+
+
+def convert_h5():
+    # File obtained from
+    # https://figshare.com/collections/Quantum_chemistry_structures_and_properties_of_134_kilo_molecules/978904
+    # (from http://quantum-machine.org/datasets/)
+    f = tarfile.TarFile("qm9.xyz.tar", "r")
+    labels = ["rA", "rB", "rC", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "U0", "U", "H", "G", "Cv"]
+    all_mols = []
+    for pt in f:
+        pt = f.extractfile(pt)  # type: ignore
+        data = pt.read().decode().splitlines()  # type: ignore
+        all_mols.append(data[-2].split()[:1] + list(map(float, data[1].split()[2:])))
+    df = pd.DataFrame(all_mols, columns=["SMILES"] + labels)
+    store = pd.HDFStore("qm9.h5", "w")
+    store["df"] = df
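
A sketch of consuming the updated __getitem__, which now returns the regression target as a shape-(1,) float tensor instead of a bare scalar (the class name and constructor arguments here are assumptions; note also that convert_h5 as written relies on tarfile being imported at the top of the file):

    import torch
    from gflownet.data.qm9 import QM9Dataset  # assumed class name for this module's dataset

    ds = QM9Dataset("qm9.h5", target="gap")  # hypothetical constructor signature
    mol, y = ds[0]  # (rdkit Mol, torch.Tensor)
    assert isinstance(y, torch.Tensor) and y.shape == (1,) and y.dtype == torch.float32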

src/gflownet/data/sampling_iterator.py

Lines changed: 59 additions & 10 deletions
@@ -14,6 +14,7 @@

 from gflownet.config import Config
 from gflownet.data.replay_buffer import ReplayBuffer
+from gflownet.envs.graph_building_env import GraphActionCategorical


 class SamplingIterator(IterableDataset):
@@ -30,11 +31,12 @@ def __init__(
         self,
         dataset: Dataset,
         model: nn.Module,
-        cfg: Config,
         ctx,
         algo,
         task,
         device,
+        batch_size: int = 1,
+        illegal_action_logreward: float = -50,
         ratio: float = 0.5,
         stream: bool = True,
         replay_buffer: ReplayBuffer = None,
@@ -51,14 +53,21 @@ def __init__(
         model: nn.Module
             The model we sample from (must be on CUDA already or share_memory() must be called so that
             parameters are synchronized between each worker)
+        ctx:
+            The context for the environment, e.g. a MolBuildingEnvContext instance
+        algo:
+            The training algorithm, e.g. a TrajectoryBalance instance
+        task: GFNTask
+            A Task instance, e.g. a MakeRingsTask instance
+        device: torch.device
+            The device the model is on
         replay_buffer: ReplayBuffer
             The replay buffer for training on past data
         batch_size: int
             The number of trajectories, each trajectory will be comprised of many graphs, so this is
             _not_ the batch size in terms of the number of graphs (that will depend on the task)
-        algo:
-            The training algorithm, e.g. a TrajectoryBalance instance
-        task: ConditionalTask
+        illegal_action_logreward: float
+            The logreward for invalid trajectories
         ratio: float
             The ratio of offline trajectories in the batch.
         stream: bool
@@ -69,13 +78,16 @@ def __init__(
         sample_cond_info: bool
             If True (default), then the dataset is a dataset of points used in offline training.
             If False, then the dataset is a dataset of preferences (e.g. used to validate the model)
-
+        random_action_prob: float
+            The probability of taking a random action, passed to the graph sampler
+        init_train_iter: int
+            The initial training iteration, incremented and passed to task.sample_conditional_information
         """
-        self.cfg = cfg
         self.data = dataset
         self.model = model
         self.replay_buffer = replay_buffer
-        self.batch_size = self.cfg.algo.global_batch_size
+        self.batch_size = batch_size
+        self.illegal_action_logreward = illegal_action_logreward
         self.offline_batch_size = int(np.ceil(self.batch_size * ratio))
         self.online_batch_size = int(np.floor(self.batch_size * (1 - ratio)))
         self.ratio = ratio
@@ -89,6 +101,8 @@ def __init__(
         self.random_action_prob = random_action_prob
         self.hindsight_ratio = hindsight_ratio
         self.train_it = init_train_iter
+        self.do_validate_batch = False  # Turn this on for debugging
+        self.log_molecule_smis = not hasattr(self.ctx, "not_a_molecule_env")  # TODO: make this a proper flag

         # Slightly weird semantics, but if we're sampling x given some fixed cond info (data)
         # then "offline" now refers to cond info and online to x, so no duplication and we don't end
@@ -100,7 +114,10 @@ def __init__(
         # don't want to initialize per-worker things just yet, such as where the log the worker writes
         # to. This must be done in __iter__, which is called by the DataLoader once this instance
         # has been copied into a new python process.
-        self.log_dir = log_dir if cfg.log_sampled_data else None
+        import warnings
+
+        warnings.warn("Fix dependency on cfg.log_sampled_data")
+        self.log_dir = log_dir  # if cfg.log_sampled_data else None
         self.log = SQLiteLog()
         self.log_hooks: List[Callable] = []
         # TODO: make this a proper flag / make a separate class for logging sampled molecules to a SQLite db
@@ -122,6 +139,9 @@ def _idx_iterator(self):
         if n == 0:
             yield np.arange(0, 0)
             return
+        assert (
+            self.offline_batch_size > 0
+        ), "offline_batch_size must be > 0 if not streaming and len(data) > 0 (have you set ratio=0?)"
        if worker_info is None:  # no multi-processing
             start, end, wid = 0, n, -1
        else:  # split the data into chunks (per-worker)
@@ -232,9 +252,10 @@ def iterator(self):
             # Compute scalar rewards from conditional information & flat rewards
             flat_rewards = torch.stack(flat_rewards)
             log_rewards = self.task.cond_info_to_logreward(cond_info, flat_rewards)
-            log_rewards[torch.logical_not(is_valid)] = self.cfg.algo.illegal_action_logreward
+            log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward

             # Computes some metrics
+            extra_info = {}
             if not self.sample_cond_info:
                 # If we're using a dataset of preferences, the user may want to know the id of the preference
                 for i, j in zip(trajs, idcs):
@@ -304,7 +325,7 @@ def iterator(self):
                 cond_info, log_rewards = self.task.relabel_condinfo_and_logrewards(
                     cond_info, log_rewards, flat_rewards, hindsight_idxs
                 )
-                log_rewards[torch.logical_not(is_valid)] = self.cfg.algo.illegal_action_logreward
+                log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward

             # Construct batch
             batch = self.algo.construct_batch(trajs, cond_info["encoding"], log_rewards)
@@ -317,9 +338,37 @@ def iterator(self):
             # TODO: we could very well just pass the cond_info dict to construct_batch above,
             # and the algo can decide what it wants to put in the batch object

+            # Only activate for debugging your environment or dataset (e.g. the dataset could be
+            # generating trajectories with illegal actions)
+            if self.do_validate_batch:
+                self.validate_batch(batch, trajs)
+
             self.train_it += worker_info.num_workers if worker_info is not None else 1
             yield batch

+    def validate_batch(self, batch, trajs):
+        for actions, atypes in [(batch.actions, self.ctx.action_type_order)] + (
+            [(batch.bck_actions, self.ctx.bck_action_type_order)]
+            if hasattr(batch, "bck_actions") and hasattr(self.ctx, "bck_action_type_order")
+            else []
+        ):
+            mask_cat = GraphActionCategorical(
+                batch,
+                [self.model._action_type_to_mask(t, batch) for t in atypes],
+                [self.model._action_type_to_key[t] for t in atypes],
+                [None for _ in atypes],
+            )
+            masked_action_is_used = 1 - mask_cat.log_prob(actions, logprobs=mask_cat.logits)
+            num_trajs = len(trajs)
+            batch_idx = torch.arange(num_trajs, device=batch.x.device).repeat_interleave(batch.traj_lens)
+            first_graph_idx = torch.zeros_like(batch.traj_lens)
+            torch.cumsum(batch.traj_lens[:-1], 0, out=first_graph_idx[1:])
+            if masked_action_is_used.sum() != 0:
+                invalid_idx = masked_action_is_used.argmax().item()
+                traj_idx = batch_idx[invalid_idx].item()
+                timestep = invalid_idx - first_graph_idx[traj_idx].item()
+                raise ValueError("Found an action that was masked out", trajs[traj_idx]["traj"][timestep])
+
     def log_generated(self, trajs, rewards, flat_rewards, cond_info):
         if self.log_molecule_smis:
             mols = [
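
With cfg: Config gone from the signature, callers now pass the batch size and illegal-action logreward explicitly. A sketch of the new call, where dataset, model, ctx, algo, and task stand for already-constructed instances:

    import torch

    iterator = SamplingIterator(
        dataset,
        model,
        ctx,
        algo,
        task,
        torch.device("cuda"),
        batch_size=64,                  # trajectories per batch, not graphs
        illegal_action_logreward=-100,  # note the constructor default is -50
        ratio=0.5,                      # half offline, half sampled from the model
    )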

src/gflownet/envs/graph_building_env.py

Lines changed: 2 additions & 2 deletions
@@ -271,9 +271,9 @@ def add_parent(a, new_g):
                 GraphAction(GraphActionType.AddNode, source=anchor, value=g.nodes[i]["v"]),
                 new_g,
             )
-        if len(g.nodes) == 1:
+        if len(g.nodes) == 1 and len(g.nodes[i]) == 1:
             # The final node is degree 0, need this special case to remove it
-            # and end up with S0, the empty graph root
+            # and end up with S0, the empty graph root (but only if it has no attrs except 'v')
             add_parent(
                 GraphAction(GraphActionType.AddNode, source=0, value=g.nodes[i]["v"]),
                 graph_without_node(g, i),
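
The new len(g.nodes[i]) == 1 check counts the node's attributes: in networkx, g.nodes[i] is the node's attribute dict, so a lone node whose only key is 'v' can be removed to reach the empty root S0, while a node carrying extra attributes cannot. A small runnable illustration:

    import networkx as nx

    g = nx.Graph()
    g.add_node(0, v="C")
    print(len(g.nodes), len(g.nodes[0]))  # 1 1 -> special case applies, parent is S0

    g.nodes[0]["charge"] = 1              # any attribute besides 'v'
    print(len(g.nodes[0]))                # 2 -> special case no longer applies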
