Add small graph task #101

Draft · wants to merge 8 commits into base: trunk
2 changes: 1 addition & 1 deletion docs/implementation_notes.md
@@ -33,4 +33,4 @@ The code contains a specific categorical distribution type for graph actions, `G

Consider for example the `AddNode` and `SetEdgeAttr` actions, one applies to nodes and one to edges. An efficient way to produce logits for these actions would be to take the node/edge embeddings and project them (e.g. via an MLP) to a `(n_nodes, n_node_actions)` and `(n_edges, n_edge_actions)` tensor respectively. We thus obtain a list of tensors representing the logits of different actions, but logits are mixed between graphs in the minibatch, so one cannot simply apply a `softmax` operator on the tensor.

The `GraphActionCategorical` class handles this and can be used to compute various other things, such as entropy, log probabilities, and so on; it can also be used to sample from the distribution.
The `GraphActionCategorical` class handles this and can be used to compute various other things, such as entropy, log probabilities, and so on; it can also be used to sample from the distribution.
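
For context on the paragraph above, here is a minimal sketch of per-graph normalization over a ragged list of logit tensors using torch_scatter. The shapes, batch-index vectors, and variable names are illustrative assumptions, not the library's actual implementation:

```python
import torch
from torch_scatter import scatter_logsumexp

n_graphs = 3
# One logits tensor per action type; rows are mixed across graphs in the minibatch.
node_logits = torch.randn(7, 4)  # (n_nodes_total, n_node_actions)
edge_logits = torch.randn(5, 2)  # (n_edges_total, n_edge_actions)
node_batch = torch.tensor([0, 0, 0, 1, 1, 2, 2])  # graph index of each node row
edge_batch = torch.tensor([0, 0, 1, 2, 2])        # graph index of each edge row

# log Z_g: logsumexp over every logit belonging to graph g, across all action types.
log_Z = torch.stack(
    [
        scatter_logsumexp(node_logits.logsumexp(dim=1), node_batch, dim=0, dim_size=n_graphs),
        scatter_logsumexp(edge_logits.logsumexp(dim=1), edge_batch, dim=0, dim_size=n_graphs),
    ]
).logsumexp(dim=0)  # (n_graphs,)

# Per-graph log-probabilities: subtract the matching graph's normalizer from each row.
node_log_probs = node_logits - log_Z[node_batch].unsqueeze(1)
edge_log_probs = edge_logits - log_Z[edge_batch].unsqueeze(1)
```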
4 changes: 3 additions & 1 deletion src/gflownet/algo/graph_sampling.py
@@ -113,7 +113,9 @@ def not_done(lst):
]
if self.sample_temp != 1:
sample_cat = copy.copy(fwd_cat)
sample_cat.logits = [i / self.sample_temp for i in fwd_cat.logits]
sample_cat.logits = [
i * m / self.sample_temp - 1000 * (1 - m) for i, m in zip(fwd_cat.logits, fwd_cat.masks)
]
actions = sample_cat.sample()
else:
actions = fwd_cat.sample()
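
A minimal sketch of the masked temperature scaling introduced above (toy tensors; the real masks and logits live inside `GraphActionCategorical`): valid actions are scaled by 1/temperature, while invalid ones are pinned to a large negative value so they cannot be sampled at any temperature.

```python
import torch

logits = torch.tensor([2.0, 0.5, -3.0])
mask = torch.tensor([1.0, 1.0, 0.0])  # third action is invalid in this state
sample_temp = 2.0

scaled = logits * mask / sample_temp - 1000 * (1 - mask)
probs = torch.softmax(scaled, dim=0)
print(probs)  # third entry is ~0: only valid actions receive probability mass
```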
23 changes: 10 additions & 13 deletions src/gflownet/algo/trajectory_balance.py
@@ -6,7 +6,7 @@
import torch.nn as nn
import torch_geometric.data as gd
from torch import Tensor
from torch_scatter import scatter, scatter_sum
from torch_scatter import scatter, scatter_logsumexp, scatter_sum

from gflownet.algo.config import TBVariant
from gflownet.algo.graph_sampling import GraphSampler
@@ -137,7 +137,7 @@ def __init__(
)
if self.cfg.variant == TBVariant.SubTB1:
self._subtb_max_len = self.global_cfg.algo.max_len + 2
self._init_subtb(torch.device("cuda")) # TODO: where are we getting device info?
self._init_subtb(torch.device(cfg.device)) # TODO: where are we getting device info?

def create_training_data_from_own_samples(
self, model: TrajectoryBalanceModel, n: int, cond_info: Tensor, random_action_prob: float
@@ -187,7 +187,11 @@ def create_training_data_from_graphs(self, graphs):
trajs: List[Dict{'traj': List[tuple[Graph, GraphAction]]}]
A list of trajectories.
"""
trajs = [{"traj": generate_forward_trajectory(i)} for i in graphs]
if hasattr(self.ctx, "relabel"):
relabel = self.ctx.relabel
else:
relabel = lambda *x: x # noqa: E731
trajs = [{"traj": [relabel(*t) for t in generate_forward_trajectory(i)]} for i in graphs]
for traj in trajs:
n_back = [
self.env.count_backward_transitions(gp, check_idempotent=self.cfg.do_correct_idempotent)
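
As a side note on the fallback above, `getattr` with a default expresses the same idea in one line; a self-contained sketch (the `_Ctx` class is just a stand-in for a context object without a `relabel` method):

```python
class _Ctx:  # hypothetical context object that has no relabel method
    pass


relabel = getattr(_Ctx(), "relabel", lambda *x: x)  # noqa: E731
g, a = relabel("graph", "action")  # identity fallback: passes the pair through unchanged
```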
@@ -365,22 +369,15 @@ def compute_batch_losses(
# Indicate that the `batch` corresponding to each action is the above
ip_log_prob = fwd_cat.log_prob(batch.ip_actions, batch=ip_batch_idces)
# take the logsumexp (because we want to sum probabilities, not log probabilities)
# TODO: numerically stable version:
p = scatter(ip_log_prob.exp(), ip_batch_idces, dim=0, dim_size=batch_idx.shape[0], reduce="sum")
# As a (reasonable) band-aid, ignore p < 1e-30, this will prevent underflows due to
# scatter(small number) = 0 on CUDA
log_p_F = p.clamp(1e-30).log()
log_p_F = scatter_logsumexp(ip_log_prob, ip_batch_idces, dim=0, dim_size=batch_idx.shape[0])

if self.cfg.do_parameterize_p_b:
# Now we repeat this but for the backward policy
bck_ip_batch_idces = torch.arange(batch.bck_ip_lens.shape[0], device=dev).repeat_interleave(
batch.bck_ip_lens
)
bck_ip_log_prob = bck_cat.log_prob(batch.bck_ip_actions, batch=bck_ip_batch_idces)
bck_p = scatter(
bck_ip_log_prob.exp(), bck_ip_batch_idces, dim=0, dim_size=batch_idx.shape[0], reduce="sum"
)
log_p_B = bck_p.clamp(1e-30).log()
log_p_B = scatter_logsumexp(bck_ip_log_prob, bck_ip_batch_idces, dim=0, dim_size=batch_idx.shape[0])
else:
# Else just naively take the logprob of the actions we took
log_p_F = fwd_cat.log_prob(batch.actions)
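
A small sketch of why `scatter_logsumexp` replaces the exp / scatter-sum / clamp / log chain: it sums probabilities directly in log space, so very negative log-probabilities no longer saturate at the 1e-30 clamp (toy values; `idx` plays the role of `ip_batch_idces`):

```python
import torch
from torch_scatter import scatter, scatter_logsumexp

log_p = torch.tensor([-100.0, -101.0, -2.0, -1.5])  # log-probs of idempotent actions
idx = torch.tensor([0, 0, 1, 1])                    # minibatch element each one belongs to

# Old path: exp() of very negative values falls below the 1e-30 clamp, so the first
# entry saturates at log(1e-30), roughly -69.1, instead of the true value near -99.7.
p = scatter(log_p.exp(), idx, dim=0, dim_size=2, reduce="sum")
old = p.clamp(1e-30).log()

# New path: summing directly in log space keeps the small entries accurate.
new = scatter_logsumexp(log_p, idx, dim=0, dim_size=2)
print(old, new)
```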
@@ -564,7 +561,7 @@ def subtb_loss_fast(self, P_F, P_B, F, R, traj_lengths):
cumul_lens = torch.cumsum(torch.cat([torch.zeros(1, device=dev), traj_lengths]), 0).long()
total_loss = torch.zeros(num_trajs, device=dev)
ar = torch.arange(max_len, device=dev)
car = torch.cumsum(ar, 0)
car = torch.cumsum(ar, 0) if self.length_normalize_losses else torch.ones_like(ar)
F_and_R = torch.cat([F, R])
R_start = F.shape[0]
for ep in range(traj_lengths.shape[0]):
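
A hedged note on the `car` change: `torch.cumsum(torch.arange(n), 0)` yields the triangular numbers, which presumably serve as a per-trajectory-length count of sub-trajectories for normalizing the summed SubTB losses when `length_normalize_losses` is set, while `torch.ones_like(ar)` leaves them un-normalized:

```python
import torch

ar = torch.arange(6)
print(torch.cumsum(ar, 0))  # tensor([ 0,  1,  3,  6, 10, 15])  -- triangular numbers
print(torch.ones_like(ar))  # tensor([1, 1, 1, 1, 1, 1])        -- no normalization
```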
1 change: 1 addition & 0 deletions src/gflownet/config.py
@@ -81,6 +81,7 @@ class Config:
"""

log_dir: str = MISSING
log_sampled_data: bool = True
device: str = "cuda"
seed: int = 0
validate_every: int = 1000
25 changes: 20 additions & 5 deletions src/gflownet/data/sampling_iterator.py
@@ -1,5 +1,6 @@
import os
import sqlite3
import traceback
from collections.abc import Iterable
from copy import deepcopy
from typing import Callable, List
@@ -110,9 +111,14 @@ def __init__(
# don't want to initialize per-worker things just yet, such as where the log the worker writes
# to. This must be done in __iter__, which is called by the DataLoader once this instance
# has been copied into a new python process.
self.log_dir = log_dir
import warnings

warnings.warn("Fix dependency on cfg.log_sampled_data")
self.log_dir = log_dir # if cfg.log_sampled_data else None
self.log = SQLiteLog()
self.log_hooks: List[Callable] = []
# TODO: make this a proper flag / make a separate class for logging sampled molecules to a SQLite db
self.log_molecule_smis = not hasattr(self.ctx, "not_a_molecule_env") and self.log_dir is not None

def add_log_hook(self, hook: Callable):
self.log_hooks.append(hook)
@@ -156,6 +162,14 @@ def __len__(self):
return len(self.data)

def __iter__(self):
try:
for x in self.iterator():
yield x
except Exception as e:
traceback.print_exc()
raise e

def iterator(self):
worker_info = torch.utils.data.get_worker_info()
self._wid = worker_info.id if worker_info is not None else 0
# Now that we know we are in a worker instance, we can initialize per-worker things
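
A minimal sketch of the wrapping pattern added above: when the iterator runs inside a DataLoader worker process, printing the traceback before re-raising keeps the original failure visible, rather than only a re-raised error surfacing in the parent process (the toy class below is a stand-in for `SamplingIterator`):

```python
import traceback


class SafeIterable:
    """Toy stand-in for SamplingIterator: the inner generator fails mid-iteration."""

    def _inner(self):
        yield 1
        raise RuntimeError("worker failed mid-iteration")

    def __iter__(self):
        try:
            yield from self._inner()
        except Exception as e:
            traceback.print_exc()  # full stack trace, printed from inside the worker
            raise e


# list(SafeIterable()) would print the underlying traceback and then re-raise.
```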
@@ -187,9 +201,7 @@ def __iter__(self):
else: # If we're not sampling the conditionals, then the idcs refer to listed preferences
num_online = num_offline
num_offline = 0
cond_info = self.task.encode_conditional_information(
steer_info=torch.stack([self.data[i] for i in idcs])
)
cond_info = self.task.encode_conditional_information(torch.stack([self.data[i] for i in idcs]))
trajs, flat_rewards = [], []

# Sample some on-policy data
@@ -244,14 +256,16 @@ def __iter__(self):
# note: we convert back into natural rewards for logging purposes
# (allows to take averages and plot in objective space)
# TODO: implement that per-task (in case they don't apply the same beta and log transformations)
rewards = torch.exp(log_rewards / cond_info["beta"])
rewards = torch.exp(log_rewards / (cond_info["beta"] if "beta" in cond_info else 1.0))
if num_online > 0 and self.log_dir is not None:
self.log_generated(
deepcopy(trajs[num_offline:]),
deepcopy(rewards[num_offline:]),
deepcopy(flat_rewards[num_offline:]),
{k: v[num_offline:] for k, v in deepcopy(cond_info).items()},
)

extra_info = {}
if num_online > 0:
for hook in self.log_hooks:
extra_info.update(
@@ -314,6 +328,7 @@ def __iter__(self):
batch.preferences = cond_info.get("preferences", None)
batch.focus_dir = cond_info.get("focus_dir", None)
batch.extra_info = extra_info
batch.trajs = trajs
# TODO: we could very well just pass the cond_info dict to construct_batch above,
# and the algo can decide what it wants to put in the batch object
