Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add fixes from the maxent paper #116

Merged
merged 36 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
94a5a4b
add fixes from the maxent paper
SobhanMP Feb 6, 2024
cc13ae6
fix style
SobhanMP Feb 6, 2024
a4d467b
fix style pt.2
SobhanMP Feb 6, 2024
1491bcf
Merge branch 'trunk' into qm9
SobhanMP Feb 8, 2024
21375ed
style
SobhanMP Feb 9, 2024
d18d159
rename print=True to print_hps
SobhanMP Feb 9, 2024
82de579
fix qm9 problems
SobhanMP Feb 9, 2024
1c34443
docu
SobhanMP Feb 9, 2024
ce5a565
rename chunked sim
SobhanMP Feb 9, 2024
b149ef2
move traj len out of reward percentilehook
SobhanMP Feb 9, 2024
86d67df
remove ruamel
SobhanMP Feb 9, 2024
d232bcb
tox
SobhanMP Feb 9, 2024
1bf8724
add flag to store all checkpoints
SobhanMP Feb 9, 2024
2ea531f
fix moohook in seh_frag and remove it in qm9
SobhanMP Feb 9, 2024
0760b87
add comment about the graceful termination of moostats
SobhanMP Feb 9, 2024
6dcce36
add a flag for predicting n
SobhanMP Feb 10, 2024
d9acdb6
REMOVE USELESS QM9 THING
SobhanMP Feb 11, 2024
1669eba
Merge branch 'qm9' of github.com:SobhanMP/gflownet_ into qm9
SobhanMP Feb 11, 2024
4f350ec
fix typo
SobhanMP Feb 11, 2024
2a20d6a
upgrade qm9
SobhanMP Feb 11, 2024
1ac2cd0
fmt
SobhanMP Feb 11, 2024
dff6e14
fmt
SobhanMP Feb 11, 2024
d7ae387
broadcast back the invalid results
SobhanMP Feb 11, 2024
994cbcb
add compute_reward_from_graph method to seh
SobhanMP Feb 11, 2024
cd2893a
use compute_reward_from_graph in seh_moo
SobhanMP Feb 11, 2024
b8930e0
unify trminate and to_close
SobhanMP Feb 16, 2024
adaa8a6
Merge branch 'qm9' of github.com:SobhanMP/gflownet_ into qm9
SobhanMP Feb 16, 2024
673a880
fmt
SobhanMP Feb 16, 2024
71d62d5
ft
SobhanMP Feb 16, 2024
4f058e5
f
SobhanMP Feb 16, 2024
f947bb3
fix typo
SobhanMP Feb 16, 2024
d36952b
fix runtime errors
SobhanMP Feb 17, 2024
ae7d740
fmt
SobhanMP Feb 17, 2024
922dc8b
close hdf5 gracefully
SobhanMP Feb 17, 2024
b5ddddf
fmt
SobhanMP Feb 17, 2024
a715adb
revert default num_graph_out
SobhanMP Feb 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ dev = [
"types-pkg_resources",
# Security pin
"gitpython>=3.1.30",
"ruamel.yaml",
SobhanMP marked this conversation as resolved.
Show resolved Hide resolved
]

[[project.authors]]
Expand Down
3 changes: 3 additions & 0 deletions src/gflownet/algo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ class AlgoConfig:
Idem but for validation, and `self.test_data`.
train_random_action_prob : float
The probability of taking a random action during training
train_det_after: Optional[int]
Do not take random actions after this number of steps
valid_random_action_prob : float
The probability of taking a random action during validation
valid_sample_cond_info : bool
Expand All @@ -126,6 +128,7 @@ class AlgoConfig:
offline_ratio: float = 0.5
valid_offline_ratio: float = 1
train_random_action_prob: float = 0.0
train_det_after: Optional[int] = None
valid_random_action_prob: float = 0.0
valid_sample_cond_info: bool = True
sampling_tau: float = 0.0
Expand Down
53 changes: 32 additions & 21 deletions src/gflownet/data/qm9.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,59 +4,70 @@
import pandas as pd
import rdkit.Chem as Chem
import torch
from rdkit.Chem import QED, Descriptors
from torch.utils.data import Dataset

from gflownet.utils import sascore


class QM9Dataset(Dataset):
    """Seeded train/validation split of the QM9 table with one or more targets.

    The table is read either from an HDF5 store (key ``"df"``) or from the raw
    xyz tar archive. It must contain a ``"SMILES"`` column and one column per
    requested target.
    """

    def __init__(self, h5_file=None, xyz_file=None, train=True, targets=None, split_seed=142857, ratio=0.9):
        """Load the table and select this split's rows.

        Parameters
        ----------
        h5_file:
            Path to an HDF5 store holding the table under the key ``"df"``.
            Takes precedence over ``xyz_file``.
        xyz_file:
            Path to the raw tar archive, parsed with ``load_tar``.
        train:
            If true keep the first ``ratio`` fraction of the shuffled rows,
            otherwise keep the remainder.
        targets:
            Column names returned as the label vector; defaults to ``["gap"]``.
        split_seed:
            Seed of the generator used to shuffle the rows before splitting.
        ratio:
            Fraction of rows assigned to the training split.

        Raises
        ------
        ValueError
            If neither ``h5_file`` nor ``xyz_file`` is given.
        """
        if h5_file is not None:
            # Close the store as soon as the frame is read instead of leaking the handle.
            with pd.HDFStore(h5_file, "r") as store:
                self.df = store["df"]
        elif xyz_file is not None:
            self.load_tar(xyz_file)
        else:
            raise ValueError("QM9Dataset requires either h5_file or xyz_file")
        # A None sentinel avoids sharing one mutable default list between instances.
        self.targets = ["gap"] if targets is None else targets
        rng = np.random.default_rng(split_seed)
        idcs = np.arange(len(self.df))
        rng.shuffle(idcs)
        n_train = int(np.floor(ratio * len(self.df)))
        self.idcs = idcs[:n_train] if train else idcs[n_train:]
        # Identity until `setup` installs the context's converter.
        self.mol_to_graph = lambda x: x

    def setup(self, task, ctx):
        """Use ``ctx.mol_to_graph`` to convert molecules returned by ``__getitem__``."""
        self.mol_to_graph = ctx.mol_to_graph

    def get_stats(self, target=None, percentile=0.95):
        """Return ``(min, max, percentile value)`` of a target column.

        ``target`` defaults to the first entry of ``self.targets``. The
        statistics are computed over the whole table, not only this split.
        """
        if target is None:
            target = self.targets[0]
        y = self.df[target]
        # Clamp so that percentile=1.0 selects the maximum instead of indexing past the end.
        k = min(int(y.shape[0] * percentile), y.shape[0] - 1)
        return y.min(), y.max(), np.sort(y)[k]

    def load_tar(self, xyz_file):
        """Populate ``self.df`` from the raw tar archive via the module-level ``load_tar``."""
        self.df = load_tar(xyz_file)

    def __len__(self):
        return len(self.idcs)

    def __getitem__(self, idx):
        """Return ``(mol_to_graph(molecule), float tensor of target values)`` for one row."""
        # `idcs` holds shuffled row positions, so index positionally rather than by label.
        row = self.idcs[idx]
        mol = Chem.MolFromSmiles(self.df["SMILES"].iloc[row])
        y = torch.tensor([self.df[t].iloc[row] for t in self.targets]).float()
        return self.mol_to_graph(mol), y


def convert_h5():
# File obtained from
# https://figshare.com/collections/Quantum_chemistry_structures_and_properties_of_134_kilo_molecules/978904
# (from http://quantum-machine.org/datasets/)
f = tarfile.TarFile("qm9.xyz.tar", "r")
def load_tar(xyz_file):
    """Parse the QM9 xyz tar archive into a DataFrame.

    For every regular member, the first token of the second-to-last line is
    taken as the SMILES string and the tokens from the third onward on the
    second line are parsed as the float properties named in ``labels``.
    Three derived columns are then added from the parsed molecules:
    ``qed``, ``sa`` and ``mw``.

    Parameters
    ----------
    xyz_file:
        Path to the tar archive.

    Returns
    -------
    pandas.DataFrame
        One row per molecule with columns ``SMILES``, the labels, ``qed``,
        ``sa`` and ``mw``.
    """
    labels = ["rA", "rB", "rC", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "U0", "U", "H", "G", "Cv"]
    all_mols = []
    # The context manager closes the archive even if parsing raises.
    with tarfile.TarFile(xyz_file, "r") as f:
        for member in f:
            pt = f.extractfile(member)
            if pt is None:
                # Directories and other non-regular members carry no payload.
                continue
            data = pt.read().decode().splitlines()
            all_mols.append(data[-2].split()[:1] + list(map(float, data[1].split()[2:])))
    df = pd.DataFrame(all_mols, columns=["SMILES"] + labels)
    mols = df["SMILES"].map(Chem.MolFromSmiles)
    df["qed"] = mols.map(QED.qed)
    df["sa"] = mols.map(sascore.calculateScore)
    df["mw"] = mols.map(Descriptors.MolWt)
    return df


def convert_h5(xyz_file="qm9.xyz.tar", h5_file="qm9.h5"):
    """Convert the raw QM9 tar archive into an HDF5 store.

    The parsed table is written under the key ``"df"``. The defaults keep the
    previously hard-coded file names, so ``convert_h5()`` behaves as before.

    Parameters
    ----------
    xyz_file:
        Path of the tar archive to read.
    h5_file:
        Path of the HDF5 file to create (opened in write mode).
    """
    # File obtained from
    # https://figshare.com/collections/Quantum_chemistry_structures_and_properties_of_134_kilo_molecules/978904
    # (from http://quantum-machine.org/datasets/)
    df = load_tar(xyz_file)
    with pd.HDFStore(h5_file, "w") as store:
        store["df"] = df
26 changes: 22 additions & 4 deletions src/gflownet/data/sampling_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sqlite3
from collections.abc import Iterable
from copy import deepcopy
from typing import Callable, List
from typing import Callable, List, Optional

import numpy as np
import torch
Expand Down Expand Up @@ -40,6 +40,7 @@ def __init__(
log_dir: str = None,
sample_cond_info: bool = True,
random_action_prob: float = 0.0,
det_after: Optional[int] = None,
hindsight_ratio: float = 0.0,
init_train_iter: int = 0,
):
Expand Down Expand Up @@ -99,7 +100,8 @@ def __init__(
self.hindsight_ratio = hindsight_ratio
self.train_it = init_train_iter
self.do_validate_batch = False # Turn this on for debugging

self.iter = 0
self.det_after = det_after
# Slightly weird semantics, but if we're sampling x given some fixed cond info (data)
# then "offline" now refers to cond info and online to x, so no duplication and we don't end
# up with 2*batch_size accidentally
Expand All @@ -122,7 +124,10 @@ def _idx_iterator(self):
if self.stream:
# If we're streaming data, just sample `offline_batch_size` indices
while True:
yield self.rng.integers(0, len(self.data), self.offline_batch_size)
if self.offline_batch_size == 0 or len(self.data) == 0:
yield np.arange(0, 0)
else:
yield self.rng.integers(0, len(self.data), self.offline_batch_size)
else:
# Otherwise, figure out which indices correspond to this worker
worker_info = torch.utils.data.get_worker_info()
Expand Down Expand Up @@ -156,6 +161,9 @@ def __len__(self):
return len(self.data)

def __iter__(self):
self.iter += 1
if self.det_after is not None and self.iter > self.det_after:
self.random_action_prob = 0
worker_info = torch.utils.data.get_worker_info()
self._wid = worker_info.id if worker_info is not None else 0
# Now that we know we are in a worker instance, we can initialize per-worker things
Expand All @@ -181,6 +189,7 @@ def __iter__(self):
flat_rewards = (
list(self.task.flat_reward_transform(torch.stack(flat_rewards))) if len(flat_rewards) else []
)

trajs = self.algo.create_training_data_from_graphs(
graphs, self.model, cond_info["encoding"][:num_offline], 0
)
Expand Down Expand Up @@ -236,8 +245,13 @@ def __iter__(self):
log_rewards = self.task.cond_info_to_logreward(cond_info, flat_rewards)
log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward

assert len(trajs) == num_online + num_offline
# Computes some metrics
extra_info = {}
extra_info = {"random_action_prob": self.random_action_prob}
if num_online > 0:
H = sum(i["fwd_logprob"] for i in trajs[num_offline:])
extra_info["entropy"] = -H / num_online
extra_info["length"] = np.mean([len(i["traj"]) for i in trajs[num_offline:]])
if not self.sample_cond_info:
# If we're using a dataset of preferences, the user may want to know the id of the preference
for i, j in zip(trajs, idcs):
Expand Down Expand Up @@ -316,6 +330,10 @@ def __iter__(self):
batch.preferences = cond_info.get("preferences", None)
batch.focus_dir = cond_info.get("focus_dir", None)
batch.extra_info = extra_info
if self.ctx.has_n():
log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs]
batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32)
batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32)
# TODO: we could very well just pass the cond_info dict to construct_batch above,
# and the algo can decide what it wants to put in the batch object

Expand Down
84 changes: 84 additions & 0 deletions src/gflownet/envs/frag_mol_env.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from collections import defaultdict
from math import log
from typing import List, Tuple

import networkx as nx
import numpy as np
import rdkit.Chem as Chem
import torch
import torch_geometric.data as gd
from scipy import special

from gflownet.envs.graph_building_env import Graph, GraphAction, GraphActionType, GraphBuildingEnvContext
from gflownet.models import bengio2021flow
Expand Down Expand Up @@ -85,6 +88,7 @@ def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, fragments: List[Tu
GraphActionType.RemoveEdgeAttr,
]
self.device = torch.device("cpu")
self.n_counter = NCounter()
self.sorted_frags = sorted(list(enumerate(self.frags_mol)), key=lambda x: -x[1].GetNumAtoms())

def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: bool = True):
Expand Down Expand Up @@ -355,6 +359,86 @@ def object_to_log_repr(self, g: Graph):
"""Convert a Graph to a string representation"""
return Chem.MolToSmiles(self.graph_to_mol(g))

def has_n(self) -> bool:
    """Report that this context implements `log_n` (always True here)."""
    return True

def log_n(self, g: Graph) -> float:
    """Return the value computed by `self.n_counter` for the graph `g`.

    `n_counter` is the `NCounter` created in `__init__`; it works in log
    space, so the result is a float rather than an integer count.
    """
    return self.n_counter(g)


class NCounter:
    """
    Dynamic program to calculate the number of trajectories to a state.
    See Appendix D of "Maximum entropy GFlowNets with soft Q-learning"
    by Mohammadpour et al 2024 (https://arxiv.org/abs/2312.14331) for a proof.

    All quantities are kept in log space: calling an instance on a graph
    returns the log-sum-exp, over every choice of root node, of the log number
    of orderings computed by `f` on the tree rooted there.
    """

    def __init__(self):
        # cache[k] holds log(k!); seeded with log(0!) = log(1!) = 0.
        self.cache = [0.0, 0.0]

    def lfac(self, arg: int):
        """Return log(arg!), extending the memo table on demand.

        Raises
        ------
        ValueError
            If ``arg`` is negative. Without this check a negative value would
            silently index the cache from the end and return a wrong result.
        """
        if arg < 0:
            raise ValueError(f"log-factorial is undefined for negative argument {arg}")
        while arg >= len(self.cache):
            self.cache.append(log(len(self.cache)) + self.cache[-1])
        return self.cache[arg]

    def lcomb(self, x, y):
        """Return log C(x, y) = log(x! / (y! (x - y)!))."""
        assert x >= y
        return self.lfac(x) - self.lfac(y) - self.lfac(x - y)

    @staticmethod
    def root_tree(og: "nx.Graph", x):
        """Return a directed copy of `og` whose edges point away from root `x`.

        Nodes are discovered by a stack-based traversal starting at `x`; each
        newly reached node gets one incoming edge carrying the merged edge
        data of both directions in `og`. The `visited` array is indexed by
        node label, so labels must be integers in ``range(len(og))``.
        """
        g = nx.DiGraph(nx.create_empty_copy(og))
        visited = np.zeros(len(g), bool)
        visited[x] = True
        stack = [x]
        while stack:
            node = stack.pop()
            for nb in nx.neighbors(og, node):
                if not visited[nb]:
                    visited[nb] = True
                    g.add_edge(node, nb, **(og.get_edge_data(node, nb) | og.get_edge_data(nb, node)))
                    stack.append(nb)

        return g

    def f(self, g, x):
        """Return the log number of orderings for the rooted tree `g` below `x`.

        For every node the recursion tracks `elem`, the number of elements in
        its subtree (the node itself, its descendants and one per edge
        attribute), and `ways`, the log count of orderings of those elements.
        """
        elem = np.full((len(g),), -1, int)
        ways = np.full((len(g),), -1, float)

        def _f(node):
            if elem[node] < 0:
                e, w = 0, 0
                for child in nx.neighbors(g, node):
                    e1, w1 = _f(child)
                    # Each attribute stored on the edge counts as one extra element.
                    n_attrs = len(g.get_edge_data(node, child))
                    for k in range(n_attrs):
                        w1 += np.log(e1 + k)
                    e1 += n_attrs

                    # Interleave this child's elements with those gathered so far.
                    w = w + w1 + self.lcomb(e + e1, e)
                    e = e + e1

                elem[node] = e + 1
                ways[node] = w
            return elem[node], ways[node]

        return _f(x)[1]

    def __call__(self, g):
        """Return the log-sum-exp of `f` over every possible root of `g`.

        An empty graph yields ``0.0`` so the return value is always a float.
        """
        if len(g) == 0:
            return 0.0

        acc = [self.f(self.root_tree(g, i), i) for i in nx.nodes(g)]
        return special.logsumexp(acc)


def _recursive_decompose(ctx, m, all_matches, a2f, frags, bonds, max_depth=9, numiters=None):
if numiters is None:
Expand Down
49 changes: 48 additions & 1 deletion src/gflownet/envs/graph_building_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,7 @@ def __init__(
slice_dict[k].to(dev) if k is not None else torch.arange(graphs.num_graphs + 1, device=dev) for k in keys
]
self.logprobs = None
self.log_n = None

if deduplicate_edge_index and "edge_index" in keys:
for idx, k in enumerate(keys):
Expand All @@ -563,6 +564,8 @@ def detach(self):
new.logits = [i.detach() for i in new.logits]
if new.logprobs is not None:
new.logprobs = [i.detach() for i in new.logprobs]
if new.log_n is not None:
new.log_n = new.log_n.detach()
return new

def to(self, device):
Expand All @@ -572,10 +575,28 @@ def to(self, device):
self.slice = [i.to(device) for i in self.slice]
if self.logprobs is not None:
self.logprobs = [i.to(device) for i in self.logprobs]
if self.log_n is not None:
self.log_n = self.log_n.to(device)
if self.masks is not None:
self.masks = [i.to(device) for i in self.masks]
return self

def log_n_actions(self):
    """Return the cached per-graph log of the summed action masks.

    For each action type the mask is broadcast to the logits' shape, summed
    over dim 1 and scatter-summed by batch index into `num_graphs` bins. The
    total is floored at 1 before the log, and the log itself is then floored
    at 1 as well. The result is stored on `self.log_n` and reused afterwards.

    NOTE(review): this counts available actions only if masks are 0/1 - confirm.
    """
    if self.log_n is not None:
        return self.log_n
    n_valid = 0
    for mask, logit, batch_idx in zip(self.masks, self.logits, self.batch):
        per_row = mask.broadcast_to(logit.shape).int().sum(1)
        n_valid = n_valid + scatter(per_row, batch_idx, dim=0, dim_size=self.num_graphs, reduce="sum")
    self.log_n = n_valid.clamp(1).float().log().clamp(1)
    return self.log_n

def _compute_batchwise_max(
self,
x: List[torch.Tensor],
Expand Down Expand Up @@ -674,8 +695,25 @@ def sample(self) -> List[Tuple[int, int, int]]:
u = [torch.rand(i.shape, device=self.dev) for i in self.logits]
# Gumbel noise
gumbel = [logit - (-noise.log()).log() for logit, noise in zip(self.logits, u)]

if self.masks is not None:
gumbel_safe = [
torch.where(
mask == 1,
torch.maximum(
x,
torch.nextafter(
torch.tensor(torch.finfo(x.dtype).min, dtype=x.dtype), torch.tensor(0.0, dtype=x.dtype)
).to(x.device),
),
torch.finfo(x.dtype).min,
)
for x, mask in zip(gumbel, self.masks)
]
else:
gumbel_safe = gumbel
# Take the argmax
return self.argmax(x=gumbel)
return self.argmax(x=gumbel_safe)

def argmax(
self,
Expand Down Expand Up @@ -922,3 +960,12 @@ def object_to_log_repr(self, g: Graph) -> str:
return json.dumps(
[[(i, g.nodes[i]) for i in g.nodes], [(e, g.edges[e]) for e in g.edges]], separators=(",", ":")
)

def has_n(self) -> bool:
    """Report whether this context implements `log_n`; this one returns False."""
    return False

def log_n(self, g) -> float:
    """Placeholder implementation: returns 0.0 for every graph `g`."""
    return 0.0

def traj_log_n(self, traj):
    """Return `log_n` of the first item of every two-item step in `traj`, as a list."""
    return [self.log_n(g) for g, _ in traj]
Loading
Loading