Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add fixes from the maxent paper #116

Merged
merged 36 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
94a5a4b
add fixes from the maxent paper
SobhanMP Feb 6, 2024
cc13ae6
fix style
SobhanMP Feb 6, 2024
a4d467b
fix style pt.2
SobhanMP Feb 6, 2024
1491bcf
Merge branch 'trunk' into qm9
SobhanMP Feb 8, 2024
21375ed
style
SobhanMP Feb 9, 2024
d18d159
rename print=True to print_hps
SobhanMP Feb 9, 2024
82de579
fix qm9 problems
SobhanMP Feb 9, 2024
1c34443
docu
SobhanMP Feb 9, 2024
ce5a565
rename chunked sim
SobhanMP Feb 9, 2024
b149ef2
move traj len out of reward percentilehook
SobhanMP Feb 9, 2024
86d67df
remove ruamel
SobhanMP Feb 9, 2024
d232bcb
tox
SobhanMP Feb 9, 2024
1bf8724
add flag to store all checkpoints
SobhanMP Feb 9, 2024
2ea531f
fix moohook in seh_frag and remove it in qm9
SobhanMP Feb 9, 2024
0760b87
add comment about the graceful termination of moostats
SobhanMP Feb 9, 2024
6dcce36
add a flag for predicting n
SobhanMP Feb 10, 2024
d9acdb6
REMOVE USELESS QM9 THING
SobhanMP Feb 11, 2024
1669eba
Merge branch 'qm9' of github.com:SobhanMP/gflownet_ into qm9
SobhanMP Feb 11, 2024
4f350ec
fix typo
SobhanMP Feb 11, 2024
2a20d6a
upgrade qm9
SobhanMP Feb 11, 2024
1ac2cd0
fmt
SobhanMP Feb 11, 2024
dff6e14
fmt
SobhanMP Feb 11, 2024
d7ae387
broadcast back the invalid results
SobhanMP Feb 11, 2024
994cbcb
add compute_reward_from_graph method to seh
SobhanMP Feb 11, 2024
cd2893a
use compute_reward_from_graph in seh_moo
SobhanMP Feb 11, 2024
b8930e0
unify terminate and to_close
SobhanMP Feb 16, 2024
adaa8a6
Merge branch 'qm9' of github.com:SobhanMP/gflownet_ into qm9
SobhanMP Feb 16, 2024
673a880
fmt
SobhanMP Feb 16, 2024
71d62d5
ft
SobhanMP Feb 16, 2024
4f058e5
f
SobhanMP Feb 16, 2024
f947bb3
fix typo
SobhanMP Feb 16, 2024
d36952b
fix runtime errors
SobhanMP Feb 17, 2024
ae7d740
fmt
SobhanMP Feb 17, 2024
922dc8b
close hdf5 gracefully
SobhanMP Feb 17, 2024
b5ddddf
fmt
SobhanMP Feb 17, 2024
a715adb
revert default num_graph_out
SobhanMP Feb 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ dependencies = [
"pyro-ppl",
"gpytorch",
"omegaconf>=2.3",
"pandas", # needed for QM9 and HDF5 support.
]

[project.optional-dependencies]
Expand Down
6 changes: 6 additions & 0 deletions src/gflownet/algo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class TBConfig:
Whether to correct for idempotent actions
do_parameterize_p_b : bool
Whether to parameterize the P_B distribution (otherwise it is uniform)
do_predict_n : bool
Whether to predict the number of paths in the graph
do_length_normalize : bool
Whether to normalize the loss by the length of the trajectory
subtb_max_len : int
Expand All @@ -45,6 +47,7 @@ class TBConfig:
variant: TBVariant = TBVariant.TB
do_correct_idempotent: bool = False
do_parameterize_p_b: bool = False
do_predict_n: bool = False
do_sample_p_b: bool = False
do_length_normalize: bool = False
subtb_max_len: int = 128
Expand Down Expand Up @@ -109,6 +112,8 @@ class AlgoConfig:
Idem but for validation, and `self.test_data`.
train_random_action_prob : float
The probability of taking a random action during training
train_det_after: Optional[int]
Do not take random actions after this number of steps
valid_random_action_prob : float
The probability of taking a random action during validation
valid_sample_cond_info : bool
Expand All @@ -126,6 +131,7 @@ class AlgoConfig:
offline_ratio: float = 0.5
valid_offline_ratio: float = 1
train_random_action_prob: float = 0.0
train_det_after: Optional[int] = None
valid_random_action_prob: float = 0.0
valid_sample_cond_info: bool = True
sampling_tau: float = 0.0
Expand Down
1 change: 1 addition & 0 deletions src/gflownet/algo/graph_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ def sample_backward_from_graphs(
def not_done(lst):
return [e for i, e in enumerate(lst) if not done[i]]

# TODO: Support random actions during backward sampling; this should be doable.
if random_action_prob > 0:
raise NotImplementedError("Random action not implemented for backward sampling")

Expand Down
14 changes: 12 additions & 2 deletions src/gflownet/algo/trajectory_balance.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Optional, Tuple
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

import networkx as nx
import numpy as np
Expand Down Expand Up @@ -206,14 +207,23 @@ def create_training_data_from_graphs(
return self.graph_sampler.sample_backward_from_graphs(
graphs, model if self.cfg.do_parameterize_p_b else None, cond_info, dev, random_action_prob
)
trajs = [{"traj": generate_forward_trajectory(i)} for i in graphs]
trajs: List[Dict[str, Any]] = [{"traj": generate_forward_trajectory(i)} for i in graphs]
for traj in trajs:
n_back = [
self.env.count_backward_transitions(gp, check_idempotent=self.cfg.do_correct_idempotent)
for gp, _ in traj["traj"][1:]
] + [1]
traj["bck_logprobs"] = (1 / torch.tensor(n_back).float()).log().to(self.ctx.device)
traj["result"] = traj["traj"][-1][0]
if self.cfg.do_parameterize_p_b:
traj["bck_a"] = [GraphAction(GraphActionType.Stop)] + [self.env.reverse(g, a) for g, a in traj["traj"]]
# There needs to be an additional node when we're parameterizing P_B,
# See sampling with parametrized P_B
traj["traj"].append(deepcopy(traj["traj"][-1]))
traj["is_sink"] = [0 for _ in traj["traj"]]
traj["is_sink"][-1] = 1
traj["is_sink"][-2] = 1
assert len(traj["bck_a"]) == len(traj["traj"]) == len(traj["is_sink"])
return trajs

def get_idempotent_actions(self, g: Graph, gd: gd.Data, gp: Graph, action: GraphAction, return_aidx: bool = True):
Expand Down
3 changes: 3 additions & 0 deletions src/gflownet/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class Config:
The number of training steps after which to validate the model
checkpoint_every : Optional[int]
The number of training steps after which to checkpoint the model
store_all_checkpoints : bool
Whether to store all checkpoints or only the last one
print_every : int
The number of training steps after which to print the training loss
start_at_step : int
Expand All @@ -85,6 +87,7 @@ class Config:
seed: int = 0
validate_every: int = 1000
checkpoint_every: Optional[int] = None
store_all_checkpoints: bool = False
print_every: int = 100
start_at_step: int = 0
num_final_gen_steps: Optional[int] = None
Expand Down
78 changes: 52 additions & 26 deletions src/gflownet/data/qm9.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,88 @@
import sys
import tarfile

import numpy as np
import pandas as pd
import rdkit.Chem as Chem
import torch
from rdkit.Chem import QED, Descriptors
from torch.utils.data import Dataset

from gflownet.utils import sascore


class QM9Dataset(Dataset):
def __init__(self, h5_file=None, xyz_file=None, train=True, target="gap", split_seed=142857, ratio=0.9):
def __init__(self, h5_file=None, xyz_file=None, train=True, targets=["gap"], split_seed=142857, ratio=0.9):
if h5_file is not None:
self.df = pd.HDFStore(h5_file, "r")["df"]

self.hdf = pd.HDFStore(h5_file, "r")
self.df = self.hdf["df"]
self.is_hdf = True
elif xyz_file is not None:
self.load_tar()
self.df = load_tar(xyz_file)
self.is_hdf = False
else:
raise ValueError("Either h5_file or xyz_file must be provided")
rng = np.random.default_rng(split_seed)
idcs = np.arange(len(self.df)) # TODO: error if there is no h5_file provided. Should h5 be required
idcs = np.arange(len(self.df))
rng.shuffle(idcs)
self.target = target
self.targets = targets
if train:
self.idcs = idcs[: int(np.floor(ratio * len(self.df)))]
else:
self.idcs = idcs[int(np.floor(ratio * len(self.df))) :]
self.mol_to_graph = lambda x: x

def get_stats(self, percentile=0.95):
y = self.df[self.target]
return y.min(), y.max(), np.sort(y)[int(y.shape[0] * percentile)]
def setup(self, task, ctx):
self.mol_to_graph = ctx.mol_to_graph

def load_tar(self, xyz_file):
f = tarfile.TarFile(xyz_file, "r")
labels = ["rA", "rB", "rC", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "U0", "U", "H", "G", "Cv"]
all_mols = []
for pt in f:
pt = f.extractfile(pt)
data = pt.read().decode().splitlines()
all_mols.append(data[-2].split()[:1] + list(map(float, data[1].split()[2:])))
self.df = pd.DataFrame(all_mols, columns=["SMILES"] + labels)
def get_stats(self, target=None, percentile=0.95):
    """Return ``(min, max, percentile value)`` for one target column.

    The statistics are computed over every row of ``self.df``, not only the
    rows selected for this split.

    Parameters
    ----------
    target : str, optional
        Name of the column of ``self.df`` to summarize. Defaults to the
        first entry of ``self.targets``.
    percentile : float
        Fraction of the sorted column at which the third value is read.

    Returns
    -------
    tuple
        The column minimum, the column maximum, and the element of the
        sorted column at position ``int(len * percentile)``.
    """
    if target is None:
        target = self.targets[0]
    y = self.df[target]
    n_rows = y.shape[0]
    # int(n_rows * percentile) equals n_rows when percentile == 1.0, which is
    # one past the end; clamp so the 100th percentile yields the maximum.
    position = min(int(n_rows * percentile), n_rows - 1)
    return y.min(), y.max(), np.sort(y)[position]

def __len__(self):
return len(self.idcs)

def __getitem__(self, idx):
return (
Chem.MolFromSmiles(self.df["SMILES"][self.idcs[idx]]),
torch.tensor([self.df[self.target][self.idcs[idx]]]).float(),
self.mol_to_graph(Chem.MolFromSmiles(self.df["SMILES"][self.idcs[idx]])),
torch.tensor([self.df[t][self.idcs[idx]] for t in self.targets]).float(),
)

def terminate(self):
    """Close the HDF5 store if this dataset was opened from an h5 file.

    Datasets built from an xyz archive hold no open store, so there is
    nothing to release in that case.
    """
    if not self.is_hdf:
        return
    self.hdf.close()

def convert_h5():
# File obtained from
# https://figshare.com/collections/Quantum_chemistry_structures_and_properties_of_134_kilo_molecules/978904
# (from http://quantum-machine.org/datasets/)
f = tarfile.TarFile("qm9.xyz.tar", "r")

def load_tar(xyz_file):
labels = ["rA", "rB", "rC", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "U0", "U", "H", "G", "Cv"]
f = tarfile.TarFile(xyz_file, "r")
all_mols = []
for pt in f:
pt = f.extractfile(pt) # type: ignore
data = pt.read().decode().splitlines() # type: ignore
all_mols.append(data[-2].split()[:1] + list(map(float, data[1].split()[2:])))
df = pd.DataFrame(all_mols, columns=["SMILES"] + labels)
store = pd.HDFStore("qm9.h5", "w")
store["df"] = df
mols = df["SMILES"].map(Chem.MolFromSmiles)
df["qed"] = mols.map(QED.qed)
df["sa"] = mols.map(sascore.calculateScore)
df["mw"] = mols.map(Descriptors.MolWt)
return df


def convert_h5(xyz_file="qm9.xyz.tar", h5_file="qm9.h5"):
    """
    Read the QM9 archive `xyz_file` and write the resulting table to `h5_file`.

    The table is stored under the key "df", which is the key `QM9Dataset`
    reads back when given an h5 file.
    """
    # The archive can be obtained from
    # https://figshare.com/collections/Quantum_chemistry_structures_and_properties_of_134_kilo_molecules/978904
    # (from http://quantum-machine.org/datasets/)
    frame = load_tar(xyz_file)
    hdf = pd.HDFStore(h5_file, "w")
    try:
        hdf["df"] = frame
    finally:
        # Always release the file handle, even if the write fails.
        hdf.close()


if __name__ == "__main__":
convert_h5(*sys.argv[1:])
26 changes: 22 additions & 4 deletions src/gflownet/data/sampling_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sqlite3
from collections.abc import Iterable
from copy import deepcopy
from typing import Callable, List
from typing import Callable, List, Optional

import numpy as np
import torch
Expand Down Expand Up @@ -40,6 +40,7 @@ def __init__(
log_dir: str = None,
sample_cond_info: bool = True,
random_action_prob: float = 0.0,
det_after: Optional[int] = None,
hindsight_ratio: float = 0.0,
init_train_iter: int = 0,
):
Expand Down Expand Up @@ -99,7 +100,8 @@ def __init__(
self.hindsight_ratio = hindsight_ratio
self.train_it = init_train_iter
self.do_validate_batch = False # Turn this on for debugging

self.iter = 0
self.det_after = det_after
# Slightly weird semantics, but if we're sampling x given some fixed cond info (data)
# then "offline" now refers to cond info and online to x, so no duplication and we don't end
# up with 2*batch_size accidentally
Expand All @@ -122,7 +124,10 @@ def _idx_iterator(self):
if self.stream:
# If we're streaming data, just sample `offline_batch_size` indices
while True:
yield self.rng.integers(0, len(self.data), self.offline_batch_size)
if self.offline_batch_size == 0 or len(self.data) == 0:
yield np.arange(0, 0)
else:
yield self.rng.integers(0, len(self.data), self.offline_batch_size)
else:
# Otherwise, figure out which indices correspond to this worker
worker_info = torch.utils.data.get_worker_info()
Expand Down Expand Up @@ -156,6 +161,9 @@ def __len__(self):
return len(self.data)

def __iter__(self):
self.iter += 1
if self.det_after is not None and self.iter > self.det_after:
self.random_action_prob = 0
worker_info = torch.utils.data.get_worker_info()
self._wid = worker_info.id if worker_info is not None else 0
# Now that we know we are in a worker instance, we can initialize per-worker things
Expand All @@ -181,6 +189,7 @@ def __iter__(self):
flat_rewards = (
list(self.task.flat_reward_transform(torch.stack(flat_rewards))) if len(flat_rewards) else []
)

trajs = self.algo.create_training_data_from_graphs(
graphs, self.model, cond_info["encoding"][:num_offline], 0
)
Expand Down Expand Up @@ -236,8 +245,13 @@ def __iter__(self):
log_rewards = self.task.cond_info_to_logreward(cond_info, flat_rewards)
log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward

assert len(trajs) == num_online + num_offline
# Computes some metrics
extra_info = {}
extra_info = {"random_action_prob": self.random_action_prob}
if num_online > 0:
H = sum(i["fwd_logprob"] for i in trajs[num_offline:])
extra_info["entropy"] = -H / num_online
extra_info["length"] = np.mean([len(i["traj"]) for i in trajs[num_offline:]])
if not self.sample_cond_info:
# If we're using a dataset of preferences, the user may want to know the id of the preference
for i, j in zip(trajs, idcs):
Expand Down Expand Up @@ -316,6 +330,10 @@ def __iter__(self):
batch.preferences = cond_info.get("preferences", None)
batch.focus_dir = cond_info.get("focus_dir", None)
batch.extra_info = extra_info
if self.ctx.has_n():
log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs]
batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32)
batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32)
# TODO: we could very well just pass the cond_info dict to construct_batch above,
# and the algo can decide what it wants to put in the batch object

Expand Down
Loading
Loading