From 94a5a4b200e577c210d9fda9b0142d2413219a63 Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Tue, 6 Feb 2024 02:05:59 -0500 Subject: [PATCH 01/33] add fixes from the maxent paper --- pyproject.toml | 1 + src/gflownet/data/qm9.py | 55 +-- src/gflownet/data/sampling_iterator.py | 26 +- src/gflownet/envs/frag_mol_env.py | 84 ++++- src/gflownet/envs/graph_building_env.py | 48 ++- src/gflownet/envs/mol_building_env.py | 28 +- src/gflownet/models/bengio2021flow.py | 23 ++ src/gflownet/models/graph_transformer.py | 2 +- src/gflownet/online_trainer.py | 52 ++- src/gflownet/tasks/config.py | 19 +- src/gflownet/tasks/qm9/qm9.py | 25 +- src/gflownet/tasks/qm9_moo.py | 384 ++++++++++++++++++++ src/gflownet/tasks/seh_frag.py | 15 +- src/gflownet/tasks/seh_frag_moo.py | 16 +- src/gflownet/trainer.py | 87 ++++- src/gflownet/utils/conditioning.py | 15 +- src/gflownet/utils/config.py | 1 + src/gflownet/utils/metrics.py | 28 +- src/gflownet/utils/multiobjective_hooks.py | 53 ++- src/gflownet/utils/multiprocessing_proxy.py | 23 +- tests/test_subtb.py | 56 ++- 21 files changed, 929 insertions(+), 112 deletions(-) create mode 100644 src/gflownet/tasks/qm9_moo.py diff --git a/pyproject.toml b/pyproject.toml index d588c636..61aa4d63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,6 +90,7 @@ dev = [ "types-pkg_resources", # Security pin "gitpython>=3.1.30", + "ruamel.yaml", ] [[project.authors]] diff --git a/src/gflownet/data/qm9.py b/src/gflownet/data/qm9.py index b26c29d2..d9a1e217 100644 --- a/src/gflownet/data/qm9.py +++ b/src/gflownet/data/qm9.py @@ -3,12 +3,15 @@ import numpy as np import pandas as pd import rdkit.Chem as Chem +from rdkit.Chem import QED, Descriptors +from rdkit.Chem.rdchem import Mol as RDMol +from gflownet.utils import metrics, sascore import torch from torch.utils.data import Dataset class QM9Dataset(Dataset): - def __init__(self, h5_file=None, xyz_file=None, train=True, target="gap", split_seed=142857, ratio=0.9): + def __init__(self, h5_file=None, xyz_file=None, train=True, targets=["gap"], split_seed=142857, ratio=0.9): if h5_file is not None: self.df = pd.HDFStore(h5_file, "r")["df"] elif xyz_file is not None: @@ -16,47 +19,55 @@ def __init__(self, h5_file=None, xyz_file=None, train=True, target="gap", split_ rng = np.random.default_rng(split_seed) idcs = np.arange(len(self.df)) # TODO: error if there is no h5_file provided. 
Should h5 be required rng.shuffle(idcs) - self.target = target + self.targets = targets if train: self.idcs = idcs[: int(np.floor(ratio * len(self.df)))] else: self.idcs = idcs[int(np.floor(ratio * len(self.df))) :] + self.mol_to_graph = lambda x: x + + def setup(self, task, ctx): + self.mol_to_graph = ctx.mol_to_graph - def get_stats(self, percentile=0.95): - y = self.df[self.target] + def get_stats(self, target=None, percentile=0.95): + if target is None: + target = self.targets[0] + y = self.df[target] return y.min(), y.max(), np.sort(y)[int(y.shape[0] * percentile)] def load_tar(self, xyz_file): - f = tarfile.TarFile(xyz_file, "r") - labels = ["rA", "rB", "rC", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "U0", "U", "H", "G", "Cv"] - all_mols = [] - for pt in f: - pt = f.extractfile(pt) - data = pt.read().decode().splitlines() - all_mols.append(data[-2].split()[:1] + list(map(float, data[1].split()[2:]))) - self.df = pd.DataFrame(all_mols, columns=["SMILES"] + labels) + self.df = load_tar(xyz_file) def __len__(self): return len(self.idcs) def __getitem__(self, idx): return ( - Chem.MolFromSmiles(self.df["SMILES"][self.idcs[idx]]), - torch.tensor([self.df[self.target][self.idcs[idx]]]).float(), + self.mol_to_graph(Chem.MolFromSmiles(self.df["SMILES"][self.idcs[idx]])), + torch.tensor([self.df[t][self.idcs[idx]] for t in self.targets]).float(), ) -def convert_h5(): - # File obtained from - # https://figshare.com/collections/Quantum_chemistry_structures_and_properties_of_134_kilo_molecules/978904 - # (from http://quantum-machine.org/datasets/) - f = tarfile.TarFile("qm9.xyz.tar", "r") +def load_tar(xyz_file): labels = ["rA", "rB", "rC", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "U0", "U", "H", "G", "Cv"] + f = tarfile.TarFile(xyz_file, "r") all_mols = [] for pt in f: - pt = f.extractfile(pt) # type: ignore + pt = f.extractfile(pt) # type: ignore3 data = pt.read().decode().splitlines() # type: ignore all_mols.append(data[-2].split()[:1] + list(map(float, data[1].split()[2:]))) df = pd.DataFrame(all_mols, columns=["SMILES"] + labels) - store = pd.HDFStore("qm9.h5", "w") - store["df"] = df + mols = df["SMILES"].map(Chem.MolFromSmiles) + df["qed"] = mols.map(QED.qed) + df["sa"] = mols.map(sascore.calculateScore) + df["mw"] = mols.map(Descriptors.MolWt) + return df + + +def convert_h5(): + # File obtained from + # https://figshare.com/collections/Quantum_chemistry_structures_and_properties_of_134_kilo_molecules/978904 + # (from http://quantum-machine.org/datasets/) + df = load_tar("qm9.xyz.tar") + with pd.HDFStore("qm9.h5", "w") as store: + store["df"] = df diff --git a/src/gflownet/data/sampling_iterator.py b/src/gflownet/data/sampling_iterator.py index c546964e..c19bfe00 100644 --- a/src/gflownet/data/sampling_iterator.py +++ b/src/gflownet/data/sampling_iterator.py @@ -2,7 +2,7 @@ import sqlite3 from collections.abc import Iterable from copy import deepcopy -from typing import Callable, List +from typing import Callable, List, Optional import numpy as np import torch @@ -40,6 +40,7 @@ def __init__( log_dir: str = None, sample_cond_info: bool = True, random_action_prob: float = 0.0, + det_after: Optional[int] = None, hindsight_ratio: float = 0.0, init_train_iter: int = 0, ): @@ -99,7 +100,8 @@ def __init__( self.hindsight_ratio = hindsight_ratio self.train_it = init_train_iter self.do_validate_batch = False # Turn this on for debugging - + self.iter = 0 + self.det_after = det_after # Slightly weird semantics, but if we're sampling x given some fixed cond info (data) # then 
"offline" now refers to cond info and online to x, so no duplication and we don't end # up with 2*batch_size accidentally @@ -122,7 +124,10 @@ def _idx_iterator(self): if self.stream: # If we're streaming data, just sample `offline_batch_size` indices while True: - yield self.rng.integers(0, len(self.data), self.offline_batch_size) + if self.offline_batch_size == 0 or len(self.data) == 0: + yield np.arange(0, 0) + else: + yield self.rng.integers(0, len(self.data), self.offline_batch_size) else: # Otherwise, figure out which indices correspond to this worker worker_info = torch.utils.data.get_worker_info() @@ -156,6 +161,9 @@ def __len__(self): return len(self.data) def __iter__(self): + self.iter += 1 + if self.det_after is not None and self.iter > self.det_after: + self.random_action_prob = 0 worker_info = torch.utils.data.get_worker_info() self._wid = worker_info.id if worker_info is not None else 0 # Now that we know we are in a worker instance, we can initialize per-worker things @@ -181,6 +189,7 @@ def __iter__(self): flat_rewards = ( list(self.task.flat_reward_transform(torch.stack(flat_rewards))) if len(flat_rewards) else [] ) + trajs = self.algo.create_training_data_from_graphs( graphs, self.model, cond_info["encoding"][:num_offline], 0 ) @@ -236,8 +245,13 @@ def __iter__(self): log_rewards = self.task.cond_info_to_logreward(cond_info, flat_rewards) log_rewards[torch.logical_not(is_valid)] = self.illegal_action_logreward + assert len(trajs) == num_online + num_offline # Computes some metrics - extra_info = {} + extra_info = {"random_action_prob": self.random_action_prob} + if num_online > 0: + H = sum(i["fwd_logprob"] for i in trajs[num_offline:]) + extra_info["entropy"] = -H / num_online + extra_info["length"] = np.mean([len(i["traj"]) for i in trajs[num_offline:]]) if not self.sample_cond_info: # If we're using a dataset of preferences, the user may want to know the id of the preference for i, j in zip(trajs, idcs): @@ -315,6 +329,10 @@ def __iter__(self): batch.preferences = cond_info.get("preferences", None) batch.focus_dir = cond_info.get("focus_dir", None) batch.extra_info = extra_info + if self.ctx.has_n(): + log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs] + batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32) + batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32) # TODO: we could very well just pass the cond_info dict to construct_batch above, # and the algo can decide what it wants to put in the batch object diff --git a/src/gflownet/envs/frag_mol_env.py b/src/gflownet/envs/frag_mol_env.py index ac118441..293b733d 100644 --- a/src/gflownet/envs/frag_mol_env.py +++ b/src/gflownet/envs/frag_mol_env.py @@ -1,10 +1,13 @@ from collections import defaultdict -from typing import List, Tuple +from math import log +from typing import Any, List, Tuple import numpy as np import rdkit.Chem as Chem import torch import torch_geometric.data as gd +import networkx as nx +from scipy import special from gflownet.envs.graph_building_env import Graph, GraphAction, GraphActionType, GraphBuildingEnvContext from gflownet.models import bengio2021flow @@ -85,6 +88,7 @@ def __init__(self, max_frags: int = 9, num_cond_dim: int = 0, fragments: List[Tu GraphActionType.RemoveEdgeAttr, ] self.device = torch.device("cpu") + self.n_counter = NCounter() self.sorted_frags = sorted(list(enumerate(self.frags_mol)), key=lambda x: -x[1].GetNumAtoms()) def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: bool = True): @@ -355,6 +359,84 
@@ def object_to_log_repr(self, g: Graph): """Convert a Graph to a string representation""" return Chem.MolToSmiles(self.graph_to_mol(g)) + def has_n(self) -> bool: + return True + + def log_n(self, g: Graph) -> int: + return self.n_counter(g) + + +class NCounter: + """ + See Appendix D of "Maximum entropy GFlowNets with soft Q-learning" Mohammadpour et al 2024 (https://arxiv.org/abs/2312.14331) for a proof. + Dynamic program to calculate the number of trajectories to a state. + + """ + def __init__(self): + # Hold the log factorial + self.cache = [0.0, 0.0] + + def lfac(self, arg: int): + while arg >= len(self.cache): + self.cache.append(log(len(self.cache)) + self.cache[-1]) + return self.cache[arg] + + def lcomb(self, x, y): + # log c(x, y) = log (x! / (y! (x - y)!)) + assert x >= y + return self.lfac(x) - self.lfac(y) - self.lfac(x - y) + + @staticmethod + def root_tree(og: nx.Graph, x): + g = nx.DiGraph(nx.create_empty_copy(og)) + visited = np.zeros(len(g), bool) + visited[x] = True + q = [x] + while len(q) > 0: # print(i, x) + x = q.pop() + for i in nx.neighbors(og, x): + if not visited[i]: + visited[i] = True + g.add_edge(x, i, **(og.get_edge_data(x, i) | og.get_edge_data(i, x))) + q.append(i) + + return g + + def f(self, g, x): + elem = np.full((len(g),), -1, int) + ways = np.full((len(g),), -1, float) + + def _f(x): + if elem[x] < 0: + e, w = 0, 0 + for i in nx.neighbors(g, x): + e1, w1 = _f(i) + # edge feature + f = len(g.get_edge_data(x, i)) + for i in range(f): + w1 += np.log(e1 + i) + e1 += f + + w = w + w1 + self.lcomb(e + e1, e) + e = e + e1 + + elem[x] = e + 1 + ways[x] = w + return elem[x], ways[x] + + return _f(x)[1] + + def __call__(self, g): + if len(g) == 0: + return 0 + + acc = [] + for i in nx.nodes(g): + rg = self.root_tree(g, i) + x = self.f(rg, i) + acc.append(x) + + return special.logsumexp(acc) def _recursive_decompose(ctx, m, all_matches, a2f, frags, bonds, max_depth=9, numiters=None): if numiters is None: diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index 2cba54cb..4381bd85 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -547,6 +547,7 @@ def __init__( slice_dict[k].to(dev) if k is not None else torch.arange(graphs.num_graphs + 1, device=dev) for k in keys ] self.logprobs = None + self.log_n = None if deduplicate_edge_index and "edge_index" in keys: for idx, k in enumerate(keys): @@ -560,6 +561,8 @@ def detach(self): new.logits = [i.detach() for i in new.logits] if new.logprobs is not None: new.logprobs = [i.detach() for i in new.logprobs] + if new.log_n is not None: + new.log_n = new.log_n.detach() return new def to(self, device): @@ -569,10 +572,28 @@ def to(self, device): self.slice = [i.to(device) for i in self.slice] if self.logprobs is not None: self.logprobs = [i.to(device) for i in self.logprobs] + if self.log_n is not None: + self.log_n = self.log_n.to(device) if self.masks is not None: self.masks = [i.to(device) for i in self.masks] return self + def log_n_actions(self): + if self.log_n is None: + self.log_n = ( + sum( + [ + scatter(m.broadcast_to(i.shape).int().sum(1), b, dim=0, dim_size=self.num_graphs, reduce="sum") + for m, i, b in zip(self.masks, self.logits, self.batch) + ] + ) + .clamp(1) + .float() + .log() + .clamp(1) + ) + return self.log_n + def _compute_batchwise_max( self, x: List[torch.Tensor], @@ -671,8 +692,25 @@ def sample(self) -> List[Tuple[int, int, int]]: u = [torch.rand(i.shape, device=self.dev) for i in self.logits] # Gumbel 
noise gumbel = [logit - (-noise.log()).log() for logit, noise in zip(self.logits, u)] + + if self.masks is not None: + gumbel_safe = [ + torch.where( + mask == 1, + torch.maximum( + x, + torch.nextafter( + torch.tensor(torch.finfo(x.dtype).min, dtype=x.dtype), torch.tensor(0.0, dtype=x.dtype) + ).to(x.device), + ), + torch.finfo(x.dtype).min, + ) + for x, mask in zip(gumbel, self.masks) + ] + else: + gumbel_safe = gumbel # Take the argmax - return self.argmax(x=gumbel) + return self.argmax(x=gumbel_safe) def argmax( self, @@ -919,3 +957,11 @@ def object_to_log_repr(self, g: Graph) -> str: return json.dumps( [[(i, g.nodes[i]) for i in g.nodes], [(e, g.edges[e]) for e in g.edges]], separators=(",", ":") ) + def has_n(self) -> bool: + return False + + def log_n(self, g) -> float: + return 0.0 + + def traj_log_n(self, traj): + return [self.log_n(g) for g, _ in traj] diff --git a/src/gflownet/envs/mol_building_env.py b/src/gflownet/envs/mol_building_env.py index 8c9c0b5d..20c05586 100644 --- a/src/gflownet/envs/mol_building_env.py +++ b/src/gflownet/envs/mol_building_env.py @@ -257,17 +257,17 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int def graph_to_Data(self, g: Graph) -> gd.Data: """Convert a networkx Graph to a torch geometric Data instance""" - x = np.zeros((max(1, len(g.nodes)), self.num_node_dim - self.num_rw_feat)) + x = np.zeros((max(1, len(g.nodes)), self.num_node_dim - self.num_rw_feat), dtype=np.float32) x[0, -1] = len(g.nodes) == 0 - add_node_mask = np.ones((x.shape[0], self.num_new_node_values)) + add_node_mask = np.ones((x.shape[0], self.num_new_node_values), dtype=np.float32) if self.max_nodes is not None and len(g.nodes) >= self.max_nodes: add_node_mask *= 0 - remove_node_mask = np.zeros((x.shape[0], 1)) + (1 if len(g) == 0 else 0) - remove_node_attr_mask = np.zeros((x.shape[0], len(self.settable_atom_attrs))) + remove_node_mask = np.zeros((x.shape[0], 1), dtype=np.float32) + (1 if len(g) == 0 else 0) + remove_node_attr_mask = np.zeros((x.shape[0], len(self.settable_atom_attrs)), dtype=np.float32) explicit_valence = {} max_valence = {} - set_node_attr_mask = np.ones((x.shape[0], self.num_node_attr_logits)) + set_node_attr_mask = np.ones((x.shape[0], self.num_node_attr_logits), dtype=np.float32) bridges = set(nx.bridges(g)) if not len(g.nodes): set_node_attr_mask *= 0 @@ -326,14 +326,14 @@ def graph_to_Data(self, g: Graph) -> gd.Data: s, e = self.atom_attr_logit_slice["expl_H"] set_node_attr_mask[i, s:e] = 0 - remove_edge_mask = np.zeros((len(g.edges), 1)) + remove_edge_mask = np.zeros((len(g.edges), 1), dtype=np.float32) for i, e in enumerate(g.edges): if e not in bridges: remove_edge_mask[i] = 1 - edge_attr = np.zeros((len(g.edges) * 2, self.num_edge_dim)) - set_edge_attr_mask = np.zeros((len(g.edges), self.num_edge_attr_logits)) - remove_edge_attr_mask = np.zeros((len(g.edges), len(self.bond_attrs))) + edge_attr = np.zeros((len(g.edges) * 2, self.num_edge_dim), dtype=np.float32) + set_edge_attr_mask = np.zeros((len(g.edges), self.num_edge_attr_logits), dtype=np.float32) + remove_edge_attr_mask = np.zeros((len(g.edges), len(self.bond_attrs)), dtype=np.float32) for i, e in enumerate(g.edges): ad = g.edges[e] for k, sl in zip(self.bond_attrs, self.bond_attr_slice): @@ -368,17 +368,21 @@ def graph_to_Data(self, g: Graph) -> gd.Data: and explicit_valence[u] + 1 <= max_valence[u] and explicit_valence[v] + 1 <= max_valence[v] ) - ] + ], + dtype=np.float32, ) data = dict( x=x, edge_index=edge_index, edge_attr=edge_attr, 
non_edge_index=non_edge_index.astype(np.int64).reshape((-1, 2)).T, - stop_mask=np.ones((1, 1)) * (len(g.nodes) > 0), # Can only stop if there's at least a node + stop_mask=np.ones((1, 1), dtype=np.float32) + * (len(g.nodes) > 0), # Can only stop if there's at least a node add_node_mask=add_node_mask, set_node_attr_mask=set_node_attr_mask, - add_edge_mask=np.ones((non_edge_index.shape[0], 1)), # Already filtered by checking for valence + add_edge_mask=np.ones( + (non_edge_index.shape[0], 1), dtype=np.float32 + ), # Already filtered by checking for valence set_edge_attr_mask=set_edge_attr_mask, remove_node_mask=remove_node_mask, remove_node_attr_mask=remove_node_attr_mask, diff --git a/src/gflownet/models/bengio2021flow.py b/src/gflownet/models/bengio2021flow.py index 9797e15f..dcd9894f 100644 --- a/src/gflownet/models/bengio2021flow.py +++ b/src/gflownet/models/bengio2021flow.py @@ -105,6 +105,29 @@ ["c1ncc2nc[nH]c2n1", [2, 6]], ] +""" +18 fragments from "Towards Understanding and Improving GFlowNet Training" by Shen et al. (https://arxiv.org/abs/2305.07170) +""" +FRAGMENTS_18 = [ + ["CO", [1, 0]], + ["O=c1[nH]cnc2[nH]cnc12", [3, 6]], + ["S", [0, 0]], + ["C1CNCCN1", [2, 5]], + ["c1cc[nH+]cc1", [3, 1]], + ["c1ccccc1", [0, 2]], + ["C1CCCCC1", [0, 2]], + ["CC(C)C", [1, 2]], + ["C1CCOCC1", [0, 2]], + ["c1cn[nH]c1", [4, 0]], + ["C1CCNC1", [2, 0]], + ["c1cncnc1", [0, 1]], + ["O=c1nc2[nH]c3ccccc3nc-2c(=O)[nH]1", [8, 4]], + ["c1ccncc1", [1, 0]], + ["O=c1nccc[nH]1", [6, 3]], + ["O=c1cc[nH]c(=O)[nH]1", [2, 4]], + ["C1CCOC1", [2, 4]], + ["C1CCNCC1", [1, 0]], +] class MPNNet(nn.Module): def __init__( diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index 8c3993f0..79d42cbf 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -169,7 +169,7 @@ def __init__( self, env_ctx, cfg: Config, - num_graph_out=1, + num_graph_out=2, do_bck=False, ): """See `GraphTransformer` for argument values""" diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index 98791be5..b7a811ac 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -16,6 +16,12 @@ from .trainer import GFNTrainer +def model_grad_norm(model): + x = 0 + for i in self.model.parameters(): + if i.grad is not None: + x += (i.grad * i.grad).sum() + return torch.sqrt(x) class StandardOnlineTrainer(GFNTrainer): def setup_model(self): @@ -43,6 +49,22 @@ def setup_data(self): self.training_data = [] self.test_data = [] + def _opt(self, params, lr=None, momentum=None): + if lr is None: + lr = self.cfg.opt.learning_rate + if momentum is None: + momentum = self.cfg.opt.momentum + if self.cfg.opt.opt == "adam": + return torch.optim.Adam( + params, + lr, + (momentum, 0.999), + weight_decay=self.cfg.opt.weight_decay, + eps=self.cfg.opt.adam_eps, + ) + + raise NotImplementedError(f"{self.opt.opt} is not implemented") + def setup(self): super().setup() self.offline_ratio = 0 @@ -55,14 +77,8 @@ def setup(self): else: Z_params = [] non_Z_params = list(self.model.parameters()) - self.opt = torch.optim.Adam( - non_Z_params, - self.cfg.opt.learning_rate, - (self.cfg.opt.momentum, 0.999), - weight_decay=self.cfg.opt.weight_decay, - eps=self.cfg.opt.adam_eps, - ) - self.opt_Z = torch.optim.Adam(Z_params, self.cfg.algo.tb.Z_learning_rate, (0.9, 0.999)) + self.opt = self._opt(non_Z_params) + self.opt_Z = self._opt(Z_params, self.cfg.algo.tb.Z_learning_rate, 0.9) self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda 
steps: 2 ** (-steps / self.cfg.opt.lr_decay)) self.lr_sched_Z = torch.optim.lr_scheduler.LambdaLR( self.opt_Z, lambda steps: 2 ** (-steps / self.cfg.algo.tb.Z_lr_decay) @@ -77,7 +93,8 @@ def setup(self): self.mb_size = self.cfg.algo.global_batch_size self.clip_grad_callback = { "value": lambda params: torch.nn.utils.clip_grad_value_(params, self.cfg.opt.clip_grad_param), - "norm": lambda params: torch.nn.utils.clip_grad_norm_(params, self.cfg.opt.clip_grad_param), + "norm": lambda params: [torch.nn.utils.clip_grad_norm_(p, self.cfg.opt.clip_grad_param) for p in params], + "total_norm": lambda params: torch.nn.utils.clip_grad_norm_(params, self.cfg.opt.clip_grad_param), "none": lambda x: None, }[self.cfg.opt.clip_grad_type] @@ -85,17 +102,20 @@ def setup(self): git_hash = git.Repo(__file__, search_parent_directories=True).head.object.hexsha[:7] self.cfg.git_hash = git_hash - os.makedirs(self.cfg.log_dir, exist_ok=True) - print("\n\nHyperparameters:\n") yaml = OmegaConf.to_yaml(self.cfg) - print(yaml) - with open(pathlib.Path(self.cfg.log_dir) / "hps.yaml", "w") as f: + os.makedirs(self.cfg.log_dir, exist_ok=True) + if self.print: + print("\n\nHyperparameters:\n") + print(yaml) + with open(pathlib.Path(self.cfg.log_dir) / "hps.yaml", "w", encoding="utf8") as f: f.write(yaml) - + def step(self, loss: Tensor): loss.backward() - for i in self.model.parameters(): - self.clip_grad_callback(i) + with torch.no_grad(): + g0 = model_grad_norm(model) + self.clip_grad_callback(self.model.parameters()) + g1 = model_grad_norm(model) self.opt.step() self.opt.zero_grad() self.opt_Z.step() diff --git a/src/gflownet/tasks/config.py b/src/gflownet/tasks/config.py index a9f6ac3f..44c997ee 100644 --- a/src/gflownet/tasks/config.py +++ b/src/gflownet/tasks/config.py @@ -4,7 +4,7 @@ @dataclass class SEHTaskConfig: - pass # SEH just uses a temperature conditional + reduced_frag: bool = False @dataclass @@ -57,8 +57,25 @@ class QM9TaskConfig: model_path: str = "./data/qm9/qm9_model.pt" +@dataclass +class QM9MOOTaskConfig: + use_steer_thermometer: bool = False + preference_type: Optional[str] = "dirichlet" + focus_type: Optional[str] = None + focus_dirs_listed: Optional[List[List[float]]] = None + focus_cosim: float = 0.0 + focus_limit_coef: float = 1.0 + focus_model_training_limits: Optional[Tuple[int, int]] = None + focus_model_state_space_res: Optional[int] = None + max_train_it: Optional[int] = None + n_valid: int = 15 + n_valid_repeats: int = 128 + objectives: List[str] = field(default_factory=lambda: ["gap", "qed", "sa"]) + + @dataclass class TasksConfig: qm9: QM9TaskConfig = QM9TaskConfig() + qm9_moo: QM9MOOTaskConfig = QM9MOOTaskConfig() seh: SEHTaskConfig = SEHTaskConfig() seh_moo: SEHMOOTaskConfig = SEHMOOTaskConfig() diff --git a/src/gflownet/tasks/qm9/qm9.py b/src/gflownet/tasks/qm9/qm9.py index e5b1d29a..dc142528 100644 --- a/src/gflownet/tasks/qm9/qm9.py +++ b/src/gflownet/tasks/qm9/qm9.py @@ -31,11 +31,12 @@ def __init__( ): self._wrap_model = wrap_model self.rng = rng - self.models = self.load_task_models(cfg.task.qm9.model_path) + self.models = self.load_task_models(cfg.task.qm9.model_path, torch.device(cfg.device)) self.dataset = dataset self.temperature_conditional = TemperatureConditional(cfg, rng) + self.num_cond_dim = self.temperature_conditional.encoding_size() # TODO: fix interface - self._min, self._max, self._percentile_95 = self.dataset.get_stats(percentile=0.05) # type: ignore + self._min, self._max, self._percentile_95 = self.dataset.get_stats("gap", percentile=0.05) # type: ignore 
self._width = self._max - self._min self._rtrans = "unit+95p" # TODO: hyperparameter @@ -60,12 +61,12 @@ def inverse_flat_reward_transform(self, rp): elif self._rtrans == "unit+95p": return (1 - rp + (1 - self._percentile_95)) * self._width + self._min - def load_task_models(self, path): + def load_task_models(self, path, device): gap_model = mxmnet.MXMNet(mxmnet.Config(128, 6, 5.0)) # TODO: this path should be part of the config? - state_dict = torch.load(path) + state_dict = torch.load(path, map_location=device) gap_model.load_state_dict(state_dict) - gap_model.cuda() + gap_model.to(device) gap_model, self.device = self._wrap_model(gap_model, send_to_device=True) return {"mxmnet_gap": gap_model} @@ -112,7 +113,10 @@ def set_default_hps(self, cfg: Config): def setup_env_context(self): self.ctx = MolBuildingEnvContext( - ["C", "N", "F", "O"], expl_H_range=[0, 1, 2, 3], num_cond_dim=32, allow_5_valence_nitrogen=True + ["C", "N", "F", "O"], + expl_H_range=[0, 1, 2, 3], + num_cond_dim=self.task.num_cond_dim, + allow_5_valence_nitrogen=True, ) # Note: we only need the allow_5_valence_nitrogen flag because of how we generate trajectories # from the dataset. For example, consider tue Nitrogen atom in this: C[NH+](C)C, when s=CN(C)C, if the action @@ -122,8 +126,8 @@ def setup_env_context(self): # (PR #98) this edge case is the only case where the ordering in which attributes are set can matter. def setup_data(self): - self.training_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=True, target="gap") - self.test_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=False, target="gap") + self.training_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=True, targets=["gap"]) + self.test_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=False, targets=["gap"]) def setup_task(self): self.task = QM9GapTask( @@ -133,6 +137,11 @@ def setup_task(self): wrap_model=self._wrap_for_mp, ) + def setup(self): + super().setup() + self.training_data.setup(self.task, self.ctx) + self.test_data.setup(self.task, self.ctx) + def main(): """Example of how this model can be run.""" diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py new file mode 100644 index 00000000..849f8baf --- /dev/null +++ b/src/gflownet/tasks/qm9_moo.py @@ -0,0 +1,384 @@ +import os +import pathlib +from typing import Any, Callable, Dict, List, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch_geometric.data as gd +from rdkit.Chem.rdchem import Mol as RDMol +from ruamel.yaml import YAML +from torch import Tensor +from torch.utils.data import Dataset +from rdkit.Chem import QED, Descriptors +from gflownet.utils import metrics, sascore +from gflownet.algo.envelope_q_learning import EnvelopeQLearning +from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce + +import gflownet.models.mxmnet as mxmnet +from gflownet.config import Config +from gflownet.data.qm9 import QM9Dataset +from gflownet.envs.mol_building_env import MolBuildingEnvContext +from gflownet.online_trainer import StandardOnlineTrainer +from gflownet.trainer import FlatRewards, GFNTask, RewardScalar +from gflownet.utils import metrics +from gflownet.utils.conditioning import ( + FocusRegionConditional, + MultiObjectiveWeightedPreferences, + TemperatureConditional, +) +from gflownet.tasks.qm9.qm9 import QM9GapTask, QM9GapTrainer +from gflownet.utils.multiobjective_hooks import MultiObjectiveStatsHook, TopKHook + + +def safe(f, x, default): + try: + return f(x) + except Exception: + return default + + +class 
QM9GapMOOTask(QM9GapTask): + """Sets up a multiobjective task where the rewards are (functions of): + - the the binding energy of a molecule to Soluble Epoxide Hydrolases. + - its QED + - its synthetic accessibility + - its molecular weight + + The proxy is pretrained, and obtained from the original GFlowNet paper, see `gflownet.models.bengio2021flow`. + """ + + def __init__( + self, + dataset: Dataset, + cfg: Config, + rng: np.random.Generator = None, + wrap_model: Callable[[nn.Module], nn.Module] = None, + ): + super().__init__(dataset, cfg, rng, wrap_model) + self.cfg = cfg + mcfg = self.cfg.task.qm9_moo + self.objectives = cfg.task.qm9_moo.objectives + self.dataset = dataset + if self.cfg.cond.focus_region.focus_type is not None: + self.focus_cond = FocusRegionConditional(self.cfg, mcfg.n_valid, rng) + else: + self.focus_cond = None + self.pref_cond = MultiObjectiveWeightedPreferences(self.cfg) + self.temperature_sample_dist = cfg.cond.temperature.sample_dist + self.temperature_dist_params = cfg.cond.temperature.dist_params + self.num_thermometer_dim = cfg.cond.temperature.num_thermometer_dim + self.num_cond_dim = ( + self.temperature_conditional.encoding_size() + + self.pref_cond.encoding_size() + + (self.focus_cond.encoding_size() if self.focus_cond is not None else 0) + ) + assert set(self.objectives) <= {"gap", "qed", "sa", "mw"} and len(self.objectives) == len(set(self.objectives)) + + def flat_reward_transform(self, y: Tensor) -> FlatRewards: + assert y.shape[-1] == len(self.objectives) + if len(y.shape) == 1: + y = y[None, :] + assert len(y.shape) == 2 + + flat_r = [] + for i, obj in enumerate(self.objectives): + preds = y[:, i] + if obj == "gap": + preds = super().flat_reward_transform(preds) + elif obj == "qed": + pass + elif obj == "sa": + preds = (10 - preds) / 9 # Turn into a [0-1] reward + elif obj == "mw": + preds = ((300 - preds) / 700 + 1).clip(0, 1) # 1 until 300 then linear decay to 0 until 1000 + else: + raise ValueError(f"{obj} not known") + flat_r.append(preds) + return FlatRewards(torch.stack(flat_r, dim=1)) + + def inverse_flat_reward_transform(self, rp): + assert rp.shape[-1] == len(self.objectives) + if len(rp.shape) == 1: + rp = rp[None, :] + assert len(rp.shape) == 2 + + flat_r = [] + for i, obj in enumerate(self.objectives): + preds = rp[:, i] + if obj == "qed": + preds = super().inverse_flat_reward_transform(preds) + elif obj == "qed": + pass + elif obj == "sa": + preds = 10 - 9 * preds + elif obj == "mw": + preds = 300 - 700 * (preds - 1) + else: + raise ValueError(f"{obj} not known") + flat_r.append(preds) + + return FlatRewards(torch.stack(flat_r, dim=1)) + + def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: + cond_info = super().sample_conditional_information(n, train_it) + pref_ci = self.pref_cond.sample(n) + focus_ci = ( + self.focus_cond.sample(n, train_it) if self.focus_cond is not None else {"encoding": torch.zeros(n, 0)} + ) + cond_info = { + **cond_info, + **pref_ci, + **focus_ci, + "encoding": torch.cat([cond_info["encoding"], pref_ci["encoding"], focus_ci["encoding"]], dim=1), + } + return cond_info + + def encode_conditional_information(self, steer_info: Tensor) -> Dict[str, Tensor]: + """ + Encode conditional information at validation-time + We use the maximum temperature beta for inference + Args: + steer_info: Tensor of shape (Batch, 2 * n_objectives) containing the preferences and focus_dirs + in that order + Returns: + Dict[str, Tensor]: Dictionary containing the encoded conditional information + """ + 
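+        # Illustration (assuming the default objectives ["gap", "qed", "sa"] from
+        # QM9MOOTaskConfig): a steer_info row is laid out positionally as
+        # [pref_gap, pref_qed, pref_sa, focus_gap, focus_qed, focus_sa],
+        # i.e. the preference weights first, then the focus direction, as the
+        # docstring above states.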
n = len(steer_info) + if self.temperature_sample_dist == "constant": + beta = torch.ones(n) * self.temperature_dist_params[0] + beta_enc = torch.zeros((n, self.num_thermometer_dim)) + else: + beta = torch.ones(n) * self.temperature_dist_params[-1] + beta_enc = torch.ones((n, self.num_thermometer_dim)) + + assert len(beta.shape) == 1, f"beta should be of shape (Batch,), got: {beta.shape}" + + # TODO: positional assumption here, should have something cleaner + preferences = steer_info[:, : len(self.objectives)].float() + focus_dir = steer_info[:, len(self.objectives) :].float() + + preferences_enc = self.pref_cond.encode(preferences) + if self.focus_cond is not None: + focus_enc = self.focus_cond.encode(focus_dir) + encoding = torch.cat([beta_enc, preferences_enc, focus_enc], 1).float() + else: + encoding = torch.cat([beta_enc, preferences_enc], 1).float() + return { + "beta": beta, + "encoding": encoding, + "preferences": preferences, + "focus_dir": focus_dir, + } + + def relabel_condinfo_and_logrewards( + self, cond_info: Dict[str, Tensor], log_rewards: Tensor, flat_rewards: FlatRewards, hindsight_idxs: Tensor + ): + # TODO: we seem to be relabeling tensors in place, could that cause a problem? + if self.focus_cond is None: + raise NotImplementedError("Hindsight relabeling only implemented for focus conditioning") + if self.focus_cond.cfg.focus_type is None: + return cond_info, log_rewards + # only keep hindsight_idxs that actually correspond to a violated constraint + _, in_focus_mask = metrics.compute_focus_coef( + flat_rewards, cond_info["focus_dir"], self.focus_cond.cfg.focus_cosim + ) + out_focus_mask = torch.logical_not(in_focus_mask) + hindsight_idxs = hindsight_idxs[out_focus_mask[hindsight_idxs]] + + # relabels the focus_dirs and log_rewards + cond_info["focus_dir"][hindsight_idxs] = nn.functional.normalize(flat_rewards[hindsight_idxs], dim=1) + + preferences_enc = self.pref_cond.encode(cond_info["preferences"]) + focus_enc = self.focus_cond.encode(cond_info["focus_dir"]) + cond_info["encoding"] = torch.cat( + [cond_info["encoding"][:, : self.num_thermometer_dim], preferences_enc, focus_enc], 1 + ) + + log_rewards = self.cond_info_to_logreward(cond_info, flat_rewards) + return cond_info, log_rewards + + def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: + if isinstance(flat_reward, list): + if isinstance(flat_reward[0], Tensor): + flat_reward = torch.stack(flat_reward) + else: + flat_reward = torch.tensor(flat_reward) + + scalarized_reward = self.pref_cond.transform(cond_info, flat_reward) + focused_reward = ( + self.focus_cond.transform(cond_info, flat_reward, scalarized_reward) + if self.focus_cond is not None + else scalarized_reward + ) + tempered_reward = self.temperature_conditional.transform(cond_info, focused_reward) + return RewardScalar(tempered_reward) + + def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: + graphs = [mxmnet.mol2graph(i) for i in mols] + assert len(graphs) == len(mols) + is_valid = torch.tensor([i is not None for i in graphs]).bool() + valid_graphs = [g for g in graphs if g is not None] + valid_mols = [m for m, g in zip(mols, graphs) if g is not None] + assert len(valid_mols) == len(valid_graphs) + if not is_valid.any(): + return FlatRewards(torch.zeros((0, len(self.objectives)))), is_valid + else: + flat_r: List[Tensor] = [] + for obj in self.objectives: + if obj == "gap": + batch = gd.Batch.from_data_list(valid_graphs) + batch.to(self.device) + preds = 
self.models["mxmnet_gap"](batch).reshape((-1,)).data.cpu() / mxmnet.HAR2EV # type: ignore[attr-defined] + preds[preds.isnan()] = 1 + elif obj == "qed": + preds = torch.tensor([safe(QED.qed, i, 0) for i in valid_mols]) + elif obj == "sa": + preds = torch.tensor([safe(sascore.calculateScore, i, 10) for i in valid_mols]) + elif obj == "mw": + preds = torch.tensor([safe(Descriptors.MolWt, i, 1000) for i in valid_mols]) + else: + raise ValueError(f"{obj} not known") + flat_r.append(preds) + assert len(preds) == len(valid_graphs), f"{len(preds)} != {len(valid_graphs)} for obj {obj}" + + flat_rewards = torch.stack(flat_r, dim=1) + return self.flat_reward_transform(flat_rewards), is_valid + + +class QM9MOOTrainer(QM9GapTrainer): + task: QM9GapMOOTask + ctx: MolBuildingEnvContext + + def set_default_hps(self, cfg: Config): + super().set_default_hps(cfg) + cfg.algo.sampling_tau = 0.95 + # We use a fixed set of preferences as our "validation set", so we must disable the preference (cond_info) + # sampling and set the offline ratio to 1 + cfg.algo.valid_sample_cond_info = False + cfg.algo.valid_offline_ratio = 1 + + def setup_task(self): + self.task = QM9GapMOOTask( + dataset=self.training_data, + cfg=self.cfg, + rng=self.rng, + wrap_model=self._wrap_for_mp, + ) + + def setup(self): + super().setup() + # self.sampling_hooks.append( + # MultiObjectiveStatsHook( + # 256, + # self.cfg.log_dir, + # compute_igd=True, + # compute_hvi=False, + # compute_hsri=False, + # compute_normed=False, + # compute_pc_entropy=True, + # compute_focus_accuracy=True if self.cfg.task.qm9_moo.focus_type is not None else False, + # focus_cosim=self.cfg.task.qm9_moo.focus_cosim, + # ) + # ) + # self.to_close.append(self.sampling_hooks[-1].keep_alive) + # instantiate preference and focus conditioning vectors for validation + + tcfg = self.cfg.task.qm9_moo + n_obj = len(tcfg.objectives) + + # making sure hyperparameters for preferences and focus regions are consistent + if not ( + tcfg.focus_type is None + or tcfg.focus_type == "centered" + or (type(tcfg.focus_type) is list and len(tcfg.focus_type) == 1) + ): + assert tcfg.preference_type is None, ( + f"Cannot use preferences with multiple focus regions, here focus_type={tcfg.focus_type} " + f"and preference_type={tcfg.preference_type}" + ) + + if type(tcfg.focus_type) is list and len(tcfg.focus_type) > 1: + n_valid = len(tcfg.focus_type) + else: + n_valid = tcfg.n_valid + + # preference vectors + if tcfg.preference_type is None: + valid_preferences = np.ones((n_valid, n_obj)) + elif tcfg.preference_type == "dirichlet": + valid_preferences = metrics.partition_hypersphere(d=n_obj, k=n_valid, normalisation="l1") + elif tcfg.preference_type == "seeded_single": + seeded_prefs = np.random.default_rng(142857 + int(self.cfg.seed)).dirichlet([1] * n_obj, n_valid) + valid_preferences = seeded_prefs[0].reshape((1, n_obj)) + self.task.seeded_preference = valid_preferences[0] + elif tcfg.preference_type == "seeded_many": + valid_preferences = np.random.default_rng(142857 + int(self.cfg.seed)).dirichlet([1] * n_obj, n_valid) + else: + raise NotImplementedError(f"Unknown preference type {self.cfg.task.qm9_moo.preference_type}") + + # TODO: this was previously reported, would be nice to serialize it + # hps["fixed_focus_dirs"] = ( + # np.unique(self.task.fixed_focus_dirs, axis=0).tolist() if self.task.fixed_focus_dirs is not None else None + # ) + if self.task.focus_cond is not None: + assert self.task.focus_cond.valid_focus_dirs.shape == ( + n_valid, + n_obj, + ), ( + "Invalid shape for 
valid_preferences, " + f"{self.task.focus_cond.valid_focus_dirs.shape} != ({n_valid}, {n_obj})" + ) + + # combine preferences and focus directions (fixed focus cosim) since they could be used together + # (not either/or). TODO: this relies on positional assumptions, should have something cleaner + valid_cond_vector = np.concatenate([valid_preferences, self.task.focus_cond.valid_focus_dirs], axis=1) + else: + valid_cond_vector = valid_preferences + + self._top_k_hook = TopKHook(10, tcfg.n_valid_repeats, n_valid) + self.test_data = RepeatedCondInfoDataset(valid_cond_vector, repeat=tcfg.n_valid_repeats) + self.valid_sampling_hooks.append(self._top_k_hook) + + self.algo.task = self.task + + def setup_data(self): + self.training_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=True, targets=self.cfg.task.qm9_moo.objectives) + self.test_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=False, targets=self.cfg.task.qm9_moo.objectives) + + def build_callbacks(self): + # We use this class-based setup to be compatible with the DeterminedAI API, but no direct + # dependency is required. + parent = self + + class TopKMetricCB: + def on_validation_end(self, metrics: Dict[str, Any]): + top_k = parent._top_k_hook.finalize() + for i in range(len(top_k)): + metrics[f"topk_rewards_{i}"] = top_k[i] + print("validation end", metrics) + + return {"topk": TopKMetricCB()} + + def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]: + if self.task.focus_cond is not None: + self.task.focus_cond.step_focus_model(batch, train_it) + return super().train_batch(batch, epoch_idx, batch_idx, train_it) + + def _save_state(self, it): + if self.task.focus_cond is not None and self.task.focus_cond.focus_model is not None: + self.task.focus_cond.focus_model.save(pathlib.Path(self.cfg.log_dir)) + return super()._save_state(it) + + +class RepeatedCondInfoDataset: + def __init__(self, cond_info_vectors, repeat): + self.cond_info_vectors = cond_info_vectors + self.repeat = repeat + + def __len__(self): + return len(self.cond_info_vectors) * self.repeat + + def __getitem__(self, idx): + assert 0 <= idx < len(self) + return torch.tensor(self.cond_info_vectors[int(idx // self.repeat)]) diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 52c63b11..ba79adcf 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -108,10 +108,11 @@ class LittleSEHDataset(Dataset): To turn on, self `cfg.algo.offline_ratio > 0`""" - def __init__(self) -> None: + def __init__(self, smis) -> None: super().__init__() self.props: List[Tensor] = [] self.mols: List[Graph] = [] + self.smis = smis def setup(self, task, ctx): rdmols = [Chem.MolFromSmiles(i) for i in SOME_MOLS] @@ -173,10 +174,18 @@ def setup_task(self): def setup_data(self): super().setup_data() - self.training_data = LittleSEHDataset() + if self.cfg.task.seh.reduced_frag: + # The examples don't work with the 18 frags + self.training_data = LittleSEHDataset([]) + else: + self.training_data = LittleSEHDataset(SOME_MOLS) def setup_env_context(self): - self.ctx = FragMolBuildingEnvContext(max_frags=self.cfg.algo.max_nodes, num_cond_dim=self.task.num_cond_dim) + self.ctx = FragMolBuildingEnvContext( + max_frags=self.cfg.algo.max_nodes, + num_cond_dim=self.task.num_cond_dim, + fragments=bengio2021flow.FRAGMENTS_18 if self.cfg.task.seh.reduced_frag else bengio2021flow.FRAGMENTS, + ) def setup(self): super().setup() diff --git a/src/gflownet/tasks/seh_frag_moo.py 
b/src/gflownet/tasks/seh_frag_moo.py index f8b432a1..3e5624e9 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -163,18 +163,23 @@ def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: Flat def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: graphs = [bengio2021flow.mol2graph(i) for i in mols] + assert len(graphs) == len(mols) is_valid = torch.tensor([i is not None for i in graphs]).bool() + valid_graphs = [g for g in graphs if g is not None] + valid_mols = [m for m, g in zip(mols, graphs) if g is not None] + assert len(valid_mols) == len(valid_graphs) if not is_valid.any(): return FlatRewards(torch.zeros((0, len(self.objectives)))), is_valid else: flat_r: List[Tensor] = [] if "seh" in self.objectives: - batch = gd.Batch.from_data_list([i for i in graphs if i is not None]) + batch = gd.Batch.from_data_list(valid_graphs) batch.to(self.device) seh_preds = self.models["seh"](batch).reshape((-1,)).clip(1e-4, 100).data.cpu() / 8 seh_preds[seh_preds.isnan()] = 0 flat_r.append(seh_preds) + assert len(seh_preds) == len(valid_graphs), f"{len(seh_preds)} != {len(valid_graphs)}" def safe(f, x, default): try: @@ -183,18 +188,21 @@ def safe(f, x, default): return default if "qed" in self.objectives: - qeds = torch.tensor([safe(QED.qed, i, 0) for i, v in zip(mols, is_valid) if v.item()]) + qeds = torch.tensor([safe(QED.qed, i, 0) for i in valid_mols]) flat_r.append(qeds) + assert len(qeds) == len(valid_graphs), f"{len(qeds)} != {len(valid_graphs)}" if "sa" in self.objectives: - sas = torch.tensor([safe(sascore.calculateScore, i, 10) for i, v in zip(mols, is_valid) if v.item()]) + sas = torch.tensor([safe(sascore.calculateScore, i, 10) for i in valid_mols]) sas = (10 - sas) / 9 # Turn into a [0-1] reward flat_r.append(sas) + assert len(sas) == len(valid_graphs), f"{len(sas)} != {len(valid_graphs)}" if "mw" in self.objectives: - molwts = torch.tensor([safe(Descriptors.MolWt, i, 1000) for i, v in zip(mols, is_valid) if v.item()]) + molwts = torch.tensor([safe(Descriptors.MolWt, i, 1000) for i in valid_mols]) molwts = ((300 - molwts) / 700 + 1).clip(0, 1) # 1 until 300 then linear decay to 0 until 1000 flat_r.append(molwts) + assert len(molwts) == len(valid_graphs), f"{len(molwts)} != {len(valid_graphs)}" flat_rewards = torch.stack(flat_r, dim=1) return FlatRewards(flat_rewards), is_valid diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 55e0159b..a6d4710c 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -31,6 +31,10 @@ class GFNAlgorithm: + updates: int = 0 + + def step(self): + self.updates += 1 def compute_batch_losses( self, model: nn.Module, batch: gd.Batch, num_bootstrap: Optional[int] = 0 ) -> Tuple[Tensor, Dict[str, Tensor]]: @@ -91,7 +95,7 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: class GFNTrainer: - def __init__(self, hps: Dict[str, Any]): + def __init__(self, hps: Dict[str, Any], print=True): """A GFlowNet trainer. Contains the main training loop in `run` and should be subclassed. Parameters @@ -101,6 +105,8 @@ def __init__(self, hps: Dict[str, Any]): device: torch.device The torch device of the main worker. 
""" + self.print = print + self.to_close = [] # self.setup should at least set these up: self.training_data: Dataset self.test_data: Dataset @@ -129,7 +135,7 @@ def __init__(self, hps: Dict[str, Any]): # Print the loss every `self.print_every` iterations self.print_every = self.cfg.print_every # These hooks allow us to compute extra quantities when sampling data - self.sampling_hooks: List[Callable] = [] + self.sampling_hooks: List[Callable] = [RewardStats()] self.valid_sampling_hooks: List[Callable] = [] # Will check if parameters are finite at every iteration (can be costly) self._validate_parameters = False @@ -173,12 +179,13 @@ def _wrap_for_mp(self, obj, send_to_device=False): if send_to_device: obj.to(self.device) if self.cfg.num_workers > 0 and obj is not None: - placeholder = mp_object_wrapper( + placeholder, keepalive = mp_object_wrapper( obj, self.cfg.num_workers, cast_types=(gd.Batch, GraphActionCategorical, SeqBatch), pickle_messages=self.cfg.pickle_mp_messages, ) + self.to_close.append(keepalive) return placeholder, torch.device("cpu") else: return obj, self.device @@ -202,6 +209,7 @@ def build_training_data_loader(self) -> DataLoader: ratio=self.cfg.algo.offline_ratio, log_dir=str(pathlib.Path(self.cfg.log_dir) / "train"), random_action_prob=self.cfg.algo.train_random_action_prob, + det_after=self.cfg.algo.train_det_after, hindsight_ratio=self.cfg.replay.hindsight_ratio, ) for hook in self.sampling_hooks: @@ -272,11 +280,14 @@ def build_final_data_loader(self) -> DataLoader: ) def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: int) -> Dict[str, Any]: + tick = time.time() + self.model.train() try: loss, info = self.algo.compute_batch_losses(self.model, batch) if not torch.isfinite(loss): raise ValueError("loss is not finite") step_info = self.step(loss) + self.algo.step() if self._validate_parameters and not all([torch.isfinite(i).all() for i in self.model.parameters()]): raise ValueError("parameters are not finite") except ValueError as e: @@ -288,12 +299,16 @@ def train_batch(self, batch: gd.Batch, epoch_idx: int, batch_idx: int, train_it: info.update(step_info) if hasattr(batch, "extra_info"): info.update(batch.extra_info) + info["train_time"] = time.time() - tick return {k: v.item() if hasattr(v, "item") else v for k, v in info.items()} def evaluate_batch(self, batch: gd.Batch, epoch_idx: int = 0, batch_idx: int = 0) -> Dict[str, Any]: + tick = time.time() + self.model.eval() loss, info = self.algo.compute_batch_losses(self.model, batch) if hasattr(batch, "extra_info"): info.update(batch.extra_info) + info["eval_time"] = time.time() - tick return {k: v.item() if hasattr(v, "item") else v for k, v in info.items()} def run(self, logger=None): @@ -316,7 +331,14 @@ def run(self, logger=None): start = self.cfg.start_at_step + 1 num_training_steps = self.cfg.num_training_steps logger.info("Starting training") + start_time = time.time() for it, batch in zip(range(start, 1 + num_training_steps), cycle(train_dl)): + # the memory fragmentation or allocation keeps growing, how often should we clean up? + # is changing the allocation strategy helpful? 
+ + if it % 1024 == 0: + gc.collect() + torch.cuda.empty_cache() epoch_idx = it // epoch_length batch_idx = it % epoch_length if self.replay_buffer is not None and len(self.replay_buffer) < self.replay_buffer.warmup: @@ -325,6 +347,8 @@ def run(self, logger=None): ) continue info = self.train_batch(batch.to(self.device), epoch_idx, batch_idx, it) + info["time_spent"] = time.time() - start_time + start_time = time.time() self.log(info, it, "train") if it % self.print_every == 0: logger.info(f"iteration {it} : " + " ".join(f"{k}:{v:.2f}" for k, v in info.items())) @@ -344,30 +368,65 @@ def run(self, logger=None): self._save_state(num_training_steps) num_final_gen_steps = self.cfg.num_final_gen_steps + final_info = {} if num_final_gen_steps: logger.info(f"Generating final {num_final_gen_steps} batches ...") for it, batch in zip( - range(num_training_steps, num_training_steps + num_final_gen_steps + 1), + range(num_training_steps + 1, num_training_steps + num_final_gen_steps + 1), cycle(final_dl), ): - pass - logger.info("Final generation steps completed.") + if hasattr(batch, "extra_info"): + for k, v in batch.extra_info.items(): + if k not in final_info: + final_info[k] = [] + if hasattr(v, "item"): + v = v.item() + final_info[k].append(v) + if it % self.print_every == 0: + logger.info(f"Generating mols {it - num_training_steps}/{num_final_gen_steps}") + final_info = {k: np.mean(v) for k, v in final_info.items()} + + logger.info("Final generation steps completed - " + " ".join(f"{k}:{v:.2f}" for k, v in final_info.items())) + self.log(final_info, num_training_steps, "final") + + # for pypy and other GC havers + del train_dl + del valid_dl + if self.cfg.num_final_gen_steps: + del final_dl def _save_state(self, it): - torch.save( - { - "models_state_dict": [self.model.state_dict()], - "cfg": self.cfg, - "step": it, - }, - open(pathlib.Path(self.cfg.log_dir) / "model_state.pt", "wb"), - ) + state = { + "models_state_dict": [self.model.state_dict()], + "cfg": self.cfg, + "step": it, + } + if self.sampling_model is not self.model: + state["sampling_model_state_dict"] = [self.sampling_model.state_dict()] + fn = pathlib.Path(self.cfg.log_dir) / "model_state.pt" + with open(fn, "wb") as fd: + torch.save( + state, + fd, + ) + shutil.copy(fn, pathlib.Path(self.cfg.log_dir) / f"model_state_{it}.pt") def log(self, info, index, key): if not hasattr(self, "_summary_writer"): self._summary_writer = torch.utils.tensorboard.SummaryWriter(self.cfg.log_dir) for k, v in info.items(): self._summary_writer.add_scalar(f"{key}_{k}", v, index) + + def close(self): + while len(self.to_close) > 0: + try: + i = self.to_close.pop() + i.close() + except Exception as e: + print(e) + + def __del__(self): + self.close() def cycle(it): diff --git a/src/gflownet/utils/conditioning.py b/src/gflownet/utils/conditioning.py index 0279167c..0e16cf15 100644 --- a/src/gflownet/utils/conditioning.py +++ b/src/gflownet/utils/conditioning.py @@ -53,9 +53,11 @@ def sample(self, n): cfg = self.cfg.cond.temperature beta = None if cfg.sample_dist == "constant": - assert isinstance(cfg.dist_params[0], float) - beta = np.array(cfg.dist_params[0]).repeat(n).astype(np.float32) - beta_enc = torch.zeros((n, cfg.num_thermometer_dim)) + if isinstance(cfg.dist_params[0], (float, int, np.int64, np.int32)): + beta = np.array(cfg.dist_params[0]).repeat(n).astype(np.float32) + beta_enc = torch.zeros((n, cfg.num_thermometer_dim)) + else: + raise ValueError(f"{cfg.dist_params[0]} is not a float)") else: if cfg.sample_dist == "gamma": loc, scale = 
cfg.dist_params @@ -102,11 +104,11 @@ def sample(self, n): elif self.cfg.preference_type == "seeded": preferences = torch.tensor(self.seeded_prefs).float().repeat(n, 1) elif self.cfg.preference_type == "dirichlet_exponential": - a = np.random.dirichlet([1] * self.num_objectives, n) + a = np.random.dirichlet([self.cfg.preference_param] * self.num_objectives, n) b = np.random.exponential(1, n)[:, None] preferences = Dirichlet(torch.tensor(a * b)).sample([1])[0].float() elif self.cfg.preference_type == "dirichlet": - m = Dirichlet(torch.FloatTensor([1.0] * self.num_objectives)) + m = Dirichlet(torch.FloatTensor([self.cfg.preference_param] * self.num_objectives)) preferences = m.sample([n]) else: raise ValueError(f"Unknown preference type {self.cfg.preference_type}") @@ -114,7 +116,8 @@ def sample(self, n): return {"preferences": preferences, "encoding": self.encode(preferences)} def transform(self, cond_info: Dict[str, Tensor], flat_reward: Tensor) -> Tensor: - scalar_logreward = (flat_reward * cond_info["preferences"]).sum(1).clamp(min=1e-30).log() + # NO LOG NO LOG NO LOG NO LOG NO LOG NO LOG + scalar_logreward = (flat_reward * cond_info["preferences"]).sum(1).clamp(min=1e-30) assert len(scalar_logreward.shape) == 1, f"scalar_logreward should be a 1D array, got {scalar_logreward.shape}" return scalar_logreward diff --git a/src/gflownet/utils/config.py b/src/gflownet/utils/config.py index db3d3905..54d0660d 100644 --- a/src/gflownet/utils/config.py +++ b/src/gflownet/utils/config.py @@ -47,6 +47,7 @@ class WeightedPreferencesConfig: - None: All rewards equally weighted""" preference_type: Optional[str] = "dirichlet" + preference_param: Optional[float] = 1.5 @dataclass diff --git a/src/gflownet/utils/metrics.py b/src/gflownet/utils/metrics.py index cc37c127..5638da28 100644 --- a/src/gflownet/utils/metrics.py +++ b/src/gflownet/utils/metrics.py @@ -522,6 +522,20 @@ def inv_transform(self, arr): return self.scale * arr + self.loc +def chunkedsim(thresh, fp, mode_fps, delta=16): + """ + Equivalent to `all(DataStructs.BulkTanimotoSimilarity(fp, mode_fps) < thresh)` + """ + assert delta > 0 + s = 0 + n = len(mode_fps) + while s < n: + e = min(s + delta, n) + for i in DataStructs.BulkTanimotoSimilarity(fp, mode_fps[s:e]): + if i >= thresh: + return False + s = e + return True # Should be calculated per preference def compute_diverse_top_k(smiles, rewards, k, thresh=0.7): # mols is a list of (reward, mol) @@ -551,7 +565,7 @@ def get_topk(rewards, k): Rewards obtained after taking the convex combination. 
Shape: number_of_preferences x number_of_samples k : int - Tok k value + Top k value Returns ---------- @@ -564,6 +578,18 @@ def get_topk(rewards, k): mean_topk = torch.mean(topk_rewards.mean(-1)) return mean_topk +def top_k_diversity(fps, r, K): + x = [] + for i in np.argsort(r)[::-1]: + y = fps[i] + if y is None: + continue + x.append(y) + if len(x) >= K: + break + s = np.array([DataStructs.BulkTanimotoSimilarity(i, x) for i in x]) + return (np.sum(s) - len(x)) / (len(x) * len(x) - len(x)) # substract the diagonal + if __name__ == "__main__": # Example for 2 dimensions diff --git a/src/gflownet/utils/multiobjective_hooks.py b/src/gflownet/utils/multiobjective_hooks.py index 4862c6c7..95c72f1f 100644 --- a/src/gflownet/utils/multiobjective_hooks.py +++ b/src/gflownet/utils/multiobjective_hooks.py @@ -10,6 +10,7 @@ from torch import Tensor from gflownet.utils import metrics +from gflownet.utils.multiprocessing_proxy import KeepAlive class MultiObjectiveStatsHook: @@ -54,9 +55,7 @@ def __init__( self.log_path = pathlib.Path(log_dir) / "pareto.pt" self.pareto_thread = threading.Thread(target=self._run_pareto_accumulation, daemon=True) self.pareto_thread.start() - - def __del__(self): - self.stop.set() + self.keep_alive = KeepAlive(self.stop) def _hsri(self, x): assert x.ndim == 2, "x should have shape (num points, num objectives)" @@ -71,15 +70,18 @@ def _hsri(self, x): def _run_pareto_accumulation(self): num_updates = 0 - while not self.stop.is_set(): + timeouts = 0 + while not self.stop.is_set() or timeouts < 200: try: r, smi, owid = self.pareto_queue.get(block=True, timeout=1) except queue.Empty: + timeouts += 1 continue except ConnectionError as e: print("Pareto Accumulation thread Queue ConnectionError", e) break + timeouts = 0 # accumulates pareto fronts across batches if self.pareto_front is None: p = self.pareto_front = r @@ -108,14 +110,19 @@ def _run_pareto_accumulation(self): if num_updates % self.save_every == 0: if self.pareto_queue.qsize() > 10: print("Warning: pareto metrics computation lagging") - torch.save( - { - "pareto_front": self.pareto_front, - "pareto_metrics": list(self.pareto_metrics), - "pareto_front_smi": self.pareto_front_smi, - }, - open(self.log_path, "wb"), - ) + self._save() + self._save() + + def _save(self): + with open(self.log_path, "wb") as fd: + torch.save( + { + "pareto_front": self.pareto_front, + "pareto_metrics": list(self.pareto_metrics), + "pareto_front_smi": self.pareto_front_smi, + }, + fd, + ) def __call__(self, trajs, rewards, flat_rewards, cond_info): # locally (in-process) accumulate flat rewards to build a better pareto estimate @@ -223,3 +230,25 @@ def finalize(self): top_ks = [np.mean(sorted(i)[-self.k :]) for i in repeats.values()] assert len(top_ks) == self.num_preferences # Make sure we got all of them? 
return top_ks + +class RewardStats: + """ + Calculate percentiles of the reward + """ + def __init__(self, idx=None): + if idx is None: + idx = [1.0, 0.75, 0.5, 0.25, 0] + self.idx = idx + + def __call__(self, trajs, rewards, flat_rewards, cond_info): + x = np.sort(flat_rewards.numpy(), axis=0) + ret = {} + y = np.sort(rewards.numpy()) + for i, idx in enumerate(self.idx): + f = max(min(math.floor(x.shape[0] * idx), x.shape[0] - 1), 0) + for j in range(x.shape[1]): + ret[f"fr_{j}_{idx:.2f}%"] = x[f, j] + ret[f"r_{idx:.2f}%"] = y[f] + + ret["sample_len"] = sum([len(i["traj"]) for i in trajs]) / len(trajs) + return ret diff --git a/src/gflownet/utils/multiprocessing_proxy.py b/src/gflownet/utils/multiprocessing_proxy.py index 9687087e..88be95ea 100644 --- a/src/gflownet/utils/multiprocessing_proxy.py +++ b/src/gflownet/utils/multiprocessing_proxy.py @@ -62,6 +62,16 @@ def __len__(self): return self.out_queue.get() +class KeepAlive: + def __init__(self, flag): + self.flag = flag + + def close(self): + self.flag.set() + + def __del__(self): + self.close() + class MPObjectProxy: """This class maintains a reference to some object and creates a `placeholder` attribute which can be safely passed to @@ -103,12 +113,10 @@ def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bo self.device = torch.device("cpu") self.cuda_types = (torch.Tensor,) + cast_types self.stop = threading.Event() + self.keepalive = KeepAlive(self.stop) self.thread = threading.Thread(target=self.run, daemon=True) self.thread.start() - def __del__(self): - self.stop.set() - def encode(self, m): if self.pickle_messages: return pickle.dumps(m) @@ -123,14 +131,18 @@ def to_cpu(self, i): return i.detach().to(torch.device("cpu")) if isinstance(i, self.cuda_types) else i def run(self): - while not self.stop.is_set(): + timeouts = 0 + + while not self.stop.is_set() or timeouts < 500: for qi, q in enumerate(self.in_queues): try: r = self.decode(q.get(True, 1e-5)) except queue.Empty: + timeouts += 1 continue except ConnectionError: break + timeouts = 0 attr, args, kwargs = r f = getattr(self.obj, attr) args = [i.to(self.device) if isinstance(i, self.cuda_types) else i for i in args] @@ -190,4 +202,5 @@ def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = Fals A placeholder object whose method calls route arguments to the main process """ - return MPObjectProxy(obj, num_workers, cast_types, pickle_messages).placeholder + x = MPObjectProxy(obj, num_workers, cast_types, pickle_messages) + return x.placeholder, x.keepalive diff --git a/tests/test_subtb.py b/tests/test_subtb.py index c4841689..898eec42 100644 --- a/tests/test_subtb.py +++ b/tests/test_subtb.py @@ -1,8 +1,11 @@ from functools import reduce +import numpy as np +import networkx as nx import torch -from gflownet.algo.trajectory_balance import subTB +from gflownet.algo.trajectory_balance import log_mixture, subTB +from gflownet.envs.frag_mol_env import NCounter def subTB_ref(P_F, P_B, F): @@ -27,3 +30,54 @@ def test_subTB(): P_B = torch.randint(1, 10, (T,)) F = torch.randint(1, 10, (T + 1,)) assert (subTB(F, P_F - P_B) == subTB_ref(P_F, P_B, F)).all() +def test_log_mixture(): + for a in [0, 0.1, 0.3, 0.5, 0.7, 0.8, 0.9, 1]: + x = -abs(torch.randn(10)) + y = -abs(torch.randn(10)) + a = 0.9 + approx = a * torch.exp(x) + (1 - a) * torch.exp(y) + assert (log_mixture(x, y, a).exp() - approx).max() < 1e-3 + + +def test_n(): + n = NCounter() + x = 0 + for i in range(1, 10): + x += np.log(i) + assert np.isclose(n.lfac(i), x) + + assert 
np.isclose(n.lcomb(5, 2), np.log(10)) + + +def test_g1(): + n = NCounter() + g = nx.Graph() + for i in range(3): + g.add_node(i) + g.add_edge(0, 1) + g.add_edge(1, 2) + rg = n.root_tree(g, 0) + assert n.f(rg, 0) == 0 + rg = n.root_tree(g, 2) + assert n.f(rg, 2) == 0 + rg = n.root_tree(g, 1) + assert np.isclose(n.f(rg, 1), np.log(2)) + + assert np.isclose(n(g), np.log(4)) + + +def test_g(): + n = NCounter() + g = nx.Graph() + for i in range(3): + g.add_node(i) + g.add_edge(0, 1) + g.add_edge(1, 2, weight=2) + rg = n.root_tree(g, 0) + assert n.f(rg, 0) == 0 + rg = n.root_tree(g, 2) + assert np.isclose(n.f(rg, 2), np.log(2)) + rg = n.root_tree(g, 1) + assert np.isclose(n.f(rg, 1), np.log(3)) + + assert np.isclose(n(g), np.log(6)) From cc13ae65de6acbc534fc5c6b21f1bad5962265ed Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Tue, 6 Feb 2024 02:11:35 -0500 Subject: [PATCH 02/33] fix style --- src/gflownet/algo/config.py | 3 +++ src/gflownet/data/qm9.py | 8 +++--- src/gflownet/data/sampling_iterator.py | 2 +- src/gflownet/envs/frag_mol_env.py | 10 +++++--- src/gflownet/envs/graph_building_env.py | 10 +++++--- src/gflownet/models/bengio2021flow.py | 2 ++ src/gflownet/online_trainer.py | 11 ++++++--- src/gflownet/tasks/qm9_moo.py | 27 +++++++-------------- src/gflownet/trainer.py | 18 +++++++++++--- src/gflownet/utils/focus_model.py | 12 ++++----- src/gflownet/utils/metrics.py | 3 +++ src/gflownet/utils/multiobjective_hooks.py | 3 +++ src/gflownet/utils/multiprocessing_proxy.py | 1 + tests/test_subtb.py | 11 ++------- 14 files changed, 68 insertions(+), 53 deletions(-) diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index f2bf178a..26839ae3 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -109,6 +109,8 @@ class AlgoConfig: Idem but for validation, and `self.test_data`. 
train_random_action_prob : float The probability of taking a random action during training + train_det_after: Optional[int] + Do not take random actions after this number of steps valid_random_action_prob : float The probability of taking a random action during validation valid_sample_cond_info : bool @@ -126,6 +128,7 @@ class AlgoConfig: offline_ratio: float = 0.5 valid_offline_ratio: float = 1 train_random_action_prob: float = 0.0 + train_det_after: Optional[int] = None valid_random_action_prob: float = 0.0 valid_sample_cond_info: bool = True sampling_tau: float = 0.0 diff --git a/src/gflownet/data/qm9.py b/src/gflownet/data/qm9.py index d9a1e217..e1dcaa5f 100644 --- a/src/gflownet/data/qm9.py +++ b/src/gflownet/data/qm9.py @@ -3,12 +3,12 @@ import numpy as np import pandas as pd import rdkit.Chem as Chem -from rdkit.Chem import QED, Descriptors -from rdkit.Chem.rdchem import Mol as RDMol -from gflownet.utils import metrics, sascore import torch +from rdkit.Chem import QED, Descriptors from torch.utils.data import Dataset +from gflownet.utils import sascore + class QM9Dataset(Dataset): def __init__(self, h5_file=None, xyz_file=None, train=True, targets=["gap"], split_seed=142857, ratio=0.9): @@ -25,7 +25,7 @@ def __init__(self, h5_file=None, xyz_file=None, train=True, targets=["gap"], spl else: self.idcs = idcs[int(np.floor(ratio * len(self.df))) :] self.mol_to_graph = lambda x: x - + def setup(self, task, ctx): self.mol_to_graph = ctx.mol_to_graph diff --git a/src/gflownet/data/sampling_iterator.py b/src/gflownet/data/sampling_iterator.py index c19bfe00..3a1985ba 100644 --- a/src/gflownet/data/sampling_iterator.py +++ b/src/gflownet/data/sampling_iterator.py @@ -189,7 +189,7 @@ def __iter__(self): flat_rewards = ( list(self.task.flat_reward_transform(torch.stack(flat_rewards))) if len(flat_rewards) else [] ) - + trajs = self.algo.create_training_data_from_graphs( graphs, self.model, cond_info["encoding"][:num_offline], 0 ) diff --git a/src/gflownet/envs/frag_mol_env.py b/src/gflownet/envs/frag_mol_env.py index 293b733d..bab9506b 100644 --- a/src/gflownet/envs/frag_mol_env.py +++ b/src/gflownet/envs/frag_mol_env.py @@ -1,12 +1,12 @@ from collections import defaultdict from math import log -from typing import Any, List, Tuple +from typing import List, Tuple +import networkx as nx import numpy as np import rdkit.Chem as Chem import torch import torch_geometric.data as gd -import networkx as nx from scipy import special from gflownet.envs.graph_building_env import Graph, GraphAction, GraphActionType, GraphBuildingEnvContext @@ -368,10 +368,11 @@ def log_n(self, g: Graph) -> int: class NCounter: """ - See Appendix D of "Maximum entropy GFlowNets with soft Q-learning" Mohammadpour et al 2024 (https://arxiv.org/abs/2312.14331) for a proof. Dynamic program to calculate the number of trajectories to a state. - + See Appendix D of "Maximum entropy GFlowNets with soft Q-learning" + by Mohammadpour et al 2024 (https://arxiv.org/abs/2312.14331) for a proof. 
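+
+    A rough usage sketch, assuming a small undirected graph (the expected value
+    below follows the assertions in tests/test_subtb.py):
+
+        import networkx as nx
+        counter = NCounter()
+        g = nx.path_graph(3)  # nodes 0-1-2 joined by two edges
+        log_n_trajs = counter(g)  # log of the number of trajectories, ~= log(4)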
""" + def __init__(self): # Hold the log factorial self.cache = [0.0, 0.0] @@ -438,6 +439,7 @@ def __call__(self, g): return special.logsumexp(acc) + def _recursive_decompose(ctx, m, all_matches, a2f, frags, bonds, max_depth=9, numiters=None): if numiters is None: numiters = [0] diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index 4381bd85..f1b12d93 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -536,9 +536,12 @@ def __init__( # This generalizes to edges and non-edges. # Append '_batch' to keys except for 'x', since TG has a special case (done by default for 'x') self.batch = [ - getattr(graphs, f"{k}_batch" if k != "x" else "batch") if k is not None - # None signals a global logit rather than a per-instance logit - else torch.arange(graphs.num_graphs, device=dev) + ( + getattr(graphs, f"{k}_batch" if k != "x" else "batch") + if k is not None + # None signals a global logit rather than a per-instance logit + else torch.arange(graphs.num_graphs, device=dev) + ) for k in keys ] # This is the cumulative sum (prefixed by 0) of N[i]s @@ -957,6 +960,7 @@ def object_to_log_repr(self, g: Graph) -> str: return json.dumps( [[(i, g.nodes[i]) for i in g.nodes], [(e, g.edges[e]) for e in g.edges]], separators=(",", ":") ) + def has_n(self) -> bool: return False diff --git a/src/gflownet/models/bengio2021flow.py b/src/gflownet/models/bengio2021flow.py index dcd9894f..77c9be32 100644 --- a/src/gflownet/models/bengio2021flow.py +++ b/src/gflownet/models/bengio2021flow.py @@ -7,6 +7,7 @@ In particular, this model class allows us to compare to the same target proxy used in that paper (sEH binding affinity prediction). """ + import gzip import os import pickle # nosec @@ -129,6 +130,7 @@ ["C1CCNCC1", [1, 0]], ] + class MPNNet(nn.Module): def __init__( self, diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index b7a811ac..03509521 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -16,13 +16,15 @@ from .trainer import GFNTrainer + def model_grad_norm(model): x = 0 - for i in self.model.parameters(): + for i in model.parameters(): if i.grad is not None: x += (i.grad * i.grad).sum() return torch.sqrt(x) + class StandardOnlineTrainer(GFNTrainer): def setup_model(self): self.model = GraphTransformerGFN( @@ -109,13 +111,13 @@ def setup(self): print(yaml) with open(pathlib.Path(self.cfg.log_dir) / "hps.yaml", "w", encoding="utf8") as f: f.write(yaml) - + def step(self, loss: Tensor): loss.backward() with torch.no_grad(): - g0 = model_grad_norm(model) + g0 = model_grad_norm(self.model) self.clip_grad_callback(self.model.parameters()) - g1 = model_grad_norm(model) + g1 = model_grad_norm(self.model) self.opt.step() self.opt.zero_grad() self.opt_Z.step() @@ -125,3 +127,4 @@ def step(self, loss: Tensor): if self.sampling_tau > 0: for a, b in zip(self.model.parameters(), self.sampling_model.parameters()): b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau)) + return {"grad_norm": g0, "grad_norm_clip": g1} diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py index 849f8baf..b51ca464 100644 --- a/src/gflownet/tasks/qm9_moo.py +++ b/src/gflownet/tasks/qm9_moo.py @@ -1,34 +1,24 @@ -import os import pathlib -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple import numpy as np import torch import torch.nn as nn import torch_geometric.data as gd +from rdkit.Chem 
import QED, Descriptors from rdkit.Chem.rdchem import Mol as RDMol -from ruamel.yaml import YAML from torch import Tensor from torch.utils.data import Dataset -from rdkit.Chem import QED, Descriptors -from gflownet.utils import metrics, sascore -from gflownet.algo.envelope_q_learning import EnvelopeQLearning -from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce import gflownet.models.mxmnet as mxmnet from gflownet.config import Config from gflownet.data.qm9 import QM9Dataset from gflownet.envs.mol_building_env import MolBuildingEnvContext -from gflownet.online_trainer import StandardOnlineTrainer -from gflownet.trainer import FlatRewards, GFNTask, RewardScalar -from gflownet.utils import metrics -from gflownet.utils.conditioning import ( - FocusRegionConditional, - MultiObjectiveWeightedPreferences, - TemperatureConditional, -) from gflownet.tasks.qm9.qm9 import QM9GapTask, QM9GapTrainer -from gflownet.utils.multiobjective_hooks import MultiObjectiveStatsHook, TopKHook +from gflownet.trainer import FlatRewards, RewardScalar +from gflownet.utils import metrics, sascore +from gflownet.utils.conditioning import FocusRegionConditional, MultiObjectiveWeightedPreferences +from gflownet.utils.multiobjective_hooks import TopKHook def safe(f, x, default): @@ -215,7 +205,7 @@ def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: Flat return RewardScalar(tempered_reward) def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: - graphs = [mxmnet.mol2graph(i) for i in mols] + graphs = [mxmnet.mol2graph(i) for i in mols] # type: ignore[attr-defined] assert len(graphs) == len(mols) is_valid = torch.tensor([i is not None for i in graphs]).bool() valid_graphs = [g for g in graphs if g is not None] @@ -229,7 +219,8 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: if obj == "gap": batch = gd.Batch.from_data_list(valid_graphs) batch.to(self.device) - preds = self.models["mxmnet_gap"](batch).reshape((-1,)).data.cpu() / mxmnet.HAR2EV # type: ignore[attr-defined] + preds = self.models["mxmnet_gap"](batch) + preds = preds.reshape((-1,)).data.cpu() / mxmnet.HAR2EV # type: ignore[attr-defined] preds[preds.isnan()] = 1 elif obj == "qed": preds = torch.tensor([safe(QED.qed, i, 0) for i in valid_mols]) diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index a6d4710c..2a941a6d 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -1,6 +1,9 @@ +import gc import os import pathlib -from typing import Any, Callable, Dict, List, NewType, Optional, Tuple +import shutil +import time +from typing import Any, Callable, Dict, List, NewType, Optional, Protocol, Tuple import numpy as np import torch @@ -18,6 +21,7 @@ from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnv, GraphBuildingEnvContext from gflownet.envs.seq_building_env import SeqBatch from gflownet.utils.misc import create_logger +from gflownet.utils.multiobjective_hooks import RewardStats from gflownet.utils.multiprocessing_proxy import mp_object_wrapper from .config import Config @@ -35,6 +39,7 @@ class GFNAlgorithm: def step(self): self.updates += 1 + def compute_batch_losses( self, model: nn.Module, batch: gd.Batch, num_bootstrap: Optional[int] = 0 ) -> Tuple[Tensor, Dict[str, Tensor]]: @@ -94,6 +99,11 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: raise NotImplementedError() +class Closable(Protocol): + def close(self): + pass + + class GFNTrainer: def __init__(self, 
hps: Dict[str, Any], print=True): """A GFlowNet trainer. Contains the main training loop in `run` and should be subclassed. @@ -106,7 +116,7 @@ def __init__(self, hps: Dict[str, Any], print=True): The torch device of the main worker. """ self.print = print - self.to_close = [] + self.to_close: List[Closable] = [] # self.setup should at least set these up: self.training_data: Dataset self.test_data: Dataset @@ -409,14 +419,14 @@ def _save_state(self, it): state, fd, ) - shutil.copy(fn, pathlib.Path(self.cfg.log_dir) / f"model_state_{it}.pt") + shutil.copy(fn, pathlib.Path(self.cfg.log_dir) / f"model_state_{it}.pt") def log(self, info, index, key): if not hasattr(self, "_summary_writer"): self._summary_writer = torch.utils.tensorboard.SummaryWriter(self.cfg.log_dir) for k, v in info.items(): self._summary_writer.add_scalar(f"{key}_{k}", v, index) - + def close(self): while len(self.to_close) > 0: try: diff --git a/src/gflownet/utils/focus_model.py b/src/gflownet/utils/focus_model.py index 14bf6c71..70cb4950 100644 --- a/src/gflownet/utils/focus_model.py +++ b/src/gflownet/utils/focus_model.py @@ -89,12 +89,12 @@ def sample_focus_directions(self, n: int): """ sampling_likelihoods = torch.zeros_like(self.focus_dir_count).float().to(self.device) sampling_likelihoods[self.focus_dir_count == 0] = self.feasible_flow - sampling_likelihoods[ - torch.logical_and(self.focus_dir_count > 0, self.focus_dir_population_count > 0) - ] = self.feasible_flow - sampling_likelihoods[ - torch.logical_and(self.focus_dir_count > 0, self.focus_dir_population_count == 0) - ] = self.infeasible_flow + sampling_likelihoods[torch.logical_and(self.focus_dir_count > 0, self.focus_dir_population_count > 0)] = ( + self.feasible_flow + ) + sampling_likelihoods[torch.logical_and(self.focus_dir_count > 0, self.focus_dir_population_count == 0)] = ( + self.infeasible_flow + ) focus_dir_indices = torch.multinomial(sampling_likelihoods, n, replacement=True) return self.focus_dir_dataset[focus_dir_indices].to("cpu") diff --git a/src/gflownet/utils/metrics.py b/src/gflownet/utils/metrics.py index 5638da28..d6ffc7dd 100644 --- a/src/gflownet/utils/metrics.py +++ b/src/gflownet/utils/metrics.py @@ -536,6 +536,8 @@ def chunkedsim(thresh, fp, mode_fps, delta=16): return False s = e return True + + # Should be calculated per preference def compute_diverse_top_k(smiles, rewards, k, thresh=0.7): # mols is a list of (reward, mol) @@ -578,6 +580,7 @@ def get_topk(rewards, k): mean_topk = torch.mean(topk_rewards.mean(-1)) return mean_topk + def top_k_diversity(fps, r, K): x = [] for i in np.argsort(r)[::-1]: diff --git a/src/gflownet/utils/multiobjective_hooks.py b/src/gflownet/utils/multiobjective_hooks.py index 95c72f1f..2113b1a7 100644 --- a/src/gflownet/utils/multiobjective_hooks.py +++ b/src/gflownet/utils/multiobjective_hooks.py @@ -1,3 +1,4 @@ +import math import pathlib import queue import threading @@ -231,10 +232,12 @@ def finalize(self): assert len(top_ks) == self.num_preferences # Make sure we got all of them? 
return top_ks + class RewardStats: """ Calculate percentiles of the reward """ + def __init__(self, idx=None): if idx is None: idx = [1.0, 0.75, 0.5, 0.25, 0] diff --git a/src/gflownet/utils/multiprocessing_proxy.py b/src/gflownet/utils/multiprocessing_proxy.py index 88be95ea..d9be2558 100644 --- a/src/gflownet/utils/multiprocessing_proxy.py +++ b/src/gflownet/utils/multiprocessing_proxy.py @@ -72,6 +72,7 @@ def close(self): def __del__(self): self.close() + class MPObjectProxy: """This class maintains a reference to some object and creates a `placeholder` attribute which can be safely passed to diff --git a/tests/test_subtb.py b/tests/test_subtb.py index 898eec42..d89ea4bf 100644 --- a/tests/test_subtb.py +++ b/tests/test_subtb.py @@ -1,10 +1,10 @@ from functools import reduce -import numpy as np import networkx as nx +import numpy as np import torch -from gflownet.algo.trajectory_balance import log_mixture, subTB +from gflownet.algo.trajectory_balance import subTB from gflownet.envs.frag_mol_env import NCounter @@ -30,13 +30,6 @@ def test_subTB(): P_B = torch.randint(1, 10, (T,)) F = torch.randint(1, 10, (T + 1,)) assert (subTB(F, P_F - P_B) == subTB_ref(P_F, P_B, F)).all() -def test_log_mixture(): - for a in [0, 0.1, 0.3, 0.5, 0.7, 0.8, 0.9, 1]: - x = -abs(torch.randn(10)) - y = -abs(torch.randn(10)) - a = 0.9 - approx = a * torch.exp(x) + (1 - a) * torch.exp(y) - assert (log_mixture(x, y, a).exp() - approx).max() < 1e-3 def test_n(): From a4d467b080f653207b21f77a975eb5629ffe1812 Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Tue, 6 Feb 2024 02:27:57 -0500 Subject: [PATCH 03/33] fix style pt.2 --- src/gflownet/data/qm9.py | 2 +- src/gflownet/models/bengio2021flow.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/gflownet/data/qm9.py b/src/gflownet/data/qm9.py index e1dcaa5f..83373411 100644 --- a/src/gflownet/data/qm9.py +++ b/src/gflownet/data/qm9.py @@ -53,7 +53,7 @@ def load_tar(xyz_file): f = tarfile.TarFile(xyz_file, "r") all_mols = [] for pt in f: - pt = f.extractfile(pt) # type: ignore3 + pt = f.extractfile(pt) # type: ignore data = pt.read().decode().splitlines() # type: ignore all_mols.append(data[-2].split()[:1] + list(map(float, data[1].split()[2:]))) df = pd.DataFrame(all_mols, columns=["SMILES"] + labels) diff --git a/src/gflownet/models/bengio2021flow.py b/src/gflownet/models/bengio2021flow.py index 77c9be32..ae71d74d 100644 --- a/src/gflownet/models/bengio2021flow.py +++ b/src/gflownet/models/bengio2021flow.py @@ -106,9 +106,9 @@ ["c1ncc2nc[nH]c2n1", [2, 6]], ] -""" -18 fragments from "Towards Understanding and Improving GFlowNet Training" by Shen et al. (https://arxiv.org/abs/2305.07170) -""" +# 18 fragments from "Towards Understanding and Improving GFlowNet Training" +# by Shen et al. 
(https://arxiv.org/abs/2305.07170) + FRAGMENTS_18 = [ ["CO", [1, 0]], ["O=c1[nH]cnc2[nH]cnc12", [3, 6]], From 21375ed8b57362bc434a8f1bf03aa79675c2109e Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Fri, 9 Feb 2024 00:54:52 -0500 Subject: [PATCH 04/33] style --- src/gflownet/tasks/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gflownet/tasks/config.py b/src/gflownet/tasks/config.py index d00a13e2..c64cc4f4 100644 --- a/src/gflownet/tasks/config.py +++ b/src/gflownet/tasks/config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import List +from typing import List, Optional, Tuple @dataclass From d18d159f4aa252b5fb6a1f431f8b946bd0ea363c Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Fri, 9 Feb 2024 00:55:48 -0500 Subject: [PATCH 05/33] rename print=True to print_hps --- src/gflownet/online_trainer.py | 2 +- src/gflownet/trainer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index 03509521..f36b8211 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -106,7 +106,7 @@ def setup(self): yaml = OmegaConf.to_yaml(self.cfg) os.makedirs(self.cfg.log_dir, exist_ok=True) - if self.print: + if self.print_hps: print("\n\nHyperparameters:\n") print(yaml) with open(pathlib.Path(self.cfg.log_dir) / "hps.yaml", "w", encoding="utf8") as f: diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 2a941a6d..d1e075c2 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -105,7 +105,7 @@ def close(self): class GFNTrainer: - def __init__(self, hps: Dict[str, Any], print=True): + def __init__(self, hps: Dict[str, Any], print_hps=True): """A GFlowNet trainer. Contains the main training loop in `run` and should be subclassed. Parameters @@ -115,7 +115,7 @@ def __init__(self, hps: Dict[str, Any], print=True): device: torch.device The torch device of the main worker. """ - self.print = print + self.print_hps = print_hps self.to_close: List[Closable] = [] # self.setup should at least set these up: self.training_data: Dataset From 82de5798df897ccf2716c7341a1e4f75d716ce29 Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Fri, 9 Feb 2024 01:02:27 -0500 Subject: [PATCH 06/33] fix qm9 problems --- src/gflownet/data/qm9.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/gflownet/data/qm9.py b/src/gflownet/data/qm9.py index 83373411..46c44ad2 100644 --- a/src/gflownet/data/qm9.py +++ b/src/gflownet/data/qm9.py @@ -1,3 +1,4 @@ +import sys import tarfile import numpy as np @@ -15,9 +16,11 @@ def __init__(self, h5_file=None, xyz_file=None, train=True, targets=["gap"], spl if h5_file is not None: self.df = pd.HDFStore(h5_file, "r")["df"] elif xyz_file is not None: - self.load_tar() + self.df = load_tar(xyz_file) + else: + raise ValueError("Either h5_file or xyz_file must be provided") rng = np.random.default_rng(split_seed) - idcs = np.arange(len(self.df)) # TODO: error if there is no h5_file provided. 
Should h5 be required + idcs = np.arange(len(self.df)) rng.shuffle(idcs) self.targets = targets if train: @@ -35,9 +38,6 @@ def get_stats(self, target=None, percentile=0.95): y = self.df[target] return y.min(), y.max(), np.sort(y)[int(y.shape[0] * percentile)] - def load_tar(self, xyz_file): - self.df = load_tar(xyz_file) - def __len__(self): return len(self.idcs) @@ -64,10 +64,17 @@ def load_tar(xyz_file): return df -def convert_h5(): +def convert_h5(xyz_file="qm9.xyz.tar", h5_file="qm9.h5"): + """ + Convert `xyz_file` and dump it into `h5_file` + """ # File obtained from # https://figshare.com/collections/Quantum_chemistry_structures_and_properties_of_134_kilo_molecules/978904 # (from http://quantum-machine.org/datasets/) - df = load_tar("qm9.xyz.tar") - with pd.HDFStore("qm9.h5", "w") as store: + df = load_tar(xyz_file) + with pd.HDFStore(h5_file, "w") as store: store["df"] = df + + +if __name__ == "__main__": + convert_h5(*sys.argv[1:]) From 1c34443d795898db3dc5e4206ae86d602cc9cebd Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Fri, 9 Feb 2024 01:04:04 -0500 Subject: [PATCH 07/33] docu --- src/gflownet/tasks/qm9_moo.py | 8 ++++---- src/gflownet/tasks/seh_frag_moo.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py index b51ca464..2da477e6 100644 --- a/src/gflownet/tasks/qm9_moo.py +++ b/src/gflownet/tasks/qm9_moo.py @@ -30,10 +30,10 @@ def safe(f, x, default): class QM9GapMOOTask(QM9GapTask): """Sets up a multiobjective task where the rewards are (functions of): - - the the binding energy of a molecule to Soluble Epoxide Hydrolases. - - its QED - - its synthetic accessibility - - its molecular weight + - the homo-lumo gap, + - its QED, + - its synthetic accessibility, + - and its molecular weight. The proxy is pretrained, and obtained from the original GFlowNet paper, see `gflownet.models.bengio2021flow`. """ diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 73432a08..6da5b833 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -27,10 +27,10 @@ class SEHMOOTask(SEHTask): """Sets up a multiobjective task where the rewards are (functions of): - - the the binding energy of a molecule to Soluble Epoxide Hydrolases. - - its QED - - its synthetic accessibility - - its molecular weight + - the binding energy of a molecule to Soluble Epoxide Hydrolases, + - its QED, + - its synthetic accessibility, + - and its molecular weight. The proxy is pretrained, and obtained from the original GFlowNet paper, see `gflownet.models.bengio2021flow`. """ From ce5a5655fb506fab8fe25f17d3452d4660ef0e71 Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Fri, 9 Feb 2024 01:07:09 -0500 Subject: [PATCH 08/33] rename chunked sim --- src/gflownet/utils/metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gflownet/utils/metrics.py b/src/gflownet/utils/metrics.py index d6ffc7dd..7f5dda4c 100644 --- a/src/gflownet/utils/metrics.py +++ b/src/gflownet/utils/metrics.py @@ -522,9 +522,9 @@ def inv_transform(self, arr): return self.scale * arr + self.loc -def chunkedsim(thresh, fp, mode_fps, delta=16): +def all_are_tanimoto_different(thresh, fp, mode_fps, delta=16): """ - Equivalent to `all(DataStructs.BulkTanimotoSimilarity(fp, mode_fps) < thresh)` + Equivalent to `all(DataStructs.BulkTanimotoSimilarity(fp, mode_fps) < thresh)` but much faster. 
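+
+    A rough usage sketch, assuming RDKit bit-vector fingerprints for a
+    hypothetical list `mols` with at least two molecules:
+
+        from rdkit.Chem import AllChem
+        fps = [AllChem.GetMorganFingerprintAsBitVect(m, 2, 2048) for m in mols]
+        is_novel = all_are_tanimoto_different(0.7, fps[0], fps[1:])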
""" assert delta > 0 s = 0 From b149ef2f6d9221ebfdc8f188b378e5a5a0b40e74 Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Fri, 9 Feb 2024 01:51:26 -0500 Subject: [PATCH 09/33] move traj len out of reward percentilehook --- src/gflownet/utils/multiobjective_hooks.py | 35 +++++++++++++++------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/src/gflownet/utils/multiobjective_hooks.py b/src/gflownet/utils/multiobjective_hooks.py index 79c4f839..623af685 100644 --- a/src/gflownet/utils/multiobjective_hooks.py +++ b/src/gflownet/utils/multiobjective_hooks.py @@ -237,25 +237,40 @@ def finalize(self): return top_ks -class RewardStats: +class RewardPercentilesHook: """ - Calculate percentiles of the reward + Calculate percentiles of the reward. + + Parameters + ---------- + idx: List[float] + The percentiles to calculate. Should be in the range [0, 1]. + Default: [1.0, 0.75, 0.5, 0.25, 0] """ - def __init__(self, idx=None): - if idx is None: - idx = [1.0, 0.75, 0.5, 0.25, 0] - self.idx = idx + def __init__(self, percentiles=None): + if percentiles is None: + percentiles = [1.0, 0.75, 0.5, 0.25, 0] + self.percentiles = percentiles def __call__(self, trajs, rewards, flat_rewards, cond_info): x = np.sort(flat_rewards.numpy(), axis=0) ret = {} y = np.sort(rewards.numpy()) - for i, idx in enumerate(self.idx): - f = max(min(math.floor(x.shape[0] * idx), x.shape[0] - 1), 0) + for p in self.percentiles: + f = max(min(math.floor(x.shape[0] * p), x.shape[0] - 1), 0) for j in range(x.shape[1]): - ret[f"fr_{j}_{idx:.2f}%"] = x[f, j] - ret[f"r_{idx:.2f}%"] = y[f] + ret[f"percentile_flat_reward_{j}_{p:.2f}"] = x[f, j] + ret[f"percentile_reward_{p:.2f}%"] = y[f] + return ret + + +class TrajectoryLengthHook: + """ + Report the average trajectory length. 
+ """ + def __call__(self, trajs, rewards, flat_rewards, cond_info): + ret = {} ret["sample_len"] = sum([len(i["traj"]) for i in trajs]) / len(trajs) return ret From 86d67dfa22ec70c422839eb09f36bbca20caaa24 Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Fri, 9 Feb 2024 01:53:41 -0500 Subject: [PATCH 10/33] remove ruamel --- pyproject.toml | 1 - src/gflownet/tasks/qm9/qm9.py | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 61aa4d63..d588c636 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,7 +90,6 @@ dev = [ "types-pkg_resources", # Security pin "gitpython>=3.1.30", - "ruamel.yaml", ] [[project.authors]] diff --git a/src/gflownet/tasks/qm9/qm9.py b/src/gflownet/tasks/qm9/qm9.py index 5458ec74..e818ea07 100644 --- a/src/gflownet/tasks/qm9/qm9.py +++ b/src/gflownet/tasks/qm9/qm9.py @@ -6,9 +6,9 @@ import torch.nn as nn import torch_geometric.data as gd from rdkit.Chem.rdchem import Mol as RDMol -from ruamel.yaml import YAML from torch import Tensor from torch.utils.data import Dataset +import yaml import gflownet.models.mxmnet as mxmnet from gflownet.config import Config @@ -154,10 +154,9 @@ def setup(self): def main(): """Example of how this model can be run.""" - yaml = YAML(typ="safe", pure=True) config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "qm9.yaml") with open(config_file, "r") as f: - hps = yaml.load(f) + hps = yaml.safe_load(f) trial = QM9GapTrainer(hps) trial.run() From d232bcb84ec18ae6d8b726032fb86302007fb636 Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Fri, 9 Feb 2024 01:57:51 -0500 Subject: [PATCH 11/33] tox --- src/gflownet/tasks/qm9/qm9.py | 2 +- src/gflownet/trainer.py | 3 +-- src/gflownet/utils/multiobjective_hooks.py | 3 +++ 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/gflownet/tasks/qm9/qm9.py b/src/gflownet/tasks/qm9/qm9.py index e818ea07..a5d6ee2f 100644 --- a/src/gflownet/tasks/qm9/qm9.py +++ b/src/gflownet/tasks/qm9/qm9.py @@ -5,10 +5,10 @@ import torch import torch.nn as nn import torch_geometric.data as gd +import yaml from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor from torch.utils.data import Dataset -import yaml import gflownet.models.mxmnet as mxmnet from gflownet.config import Config diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index d1e075c2..d0a10016 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -21,7 +21,6 @@ from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnv, GraphBuildingEnvContext from gflownet.envs.seq_building_env import SeqBatch from gflownet.utils.misc import create_logger -from gflownet.utils.multiobjective_hooks import RewardStats from gflownet.utils.multiprocessing_proxy import mp_object_wrapper from .config import Config @@ -145,7 +144,7 @@ def __init__(self, hps: Dict[str, Any], print_hps=True): # Print the loss every `self.print_every` iterations self.print_every = self.cfg.print_every # These hooks allow us to compute extra quantities when sampling data - self.sampling_hooks: List[Callable] = [RewardStats()] + self.sampling_hooks: List[Callable] = [] self.valid_sampling_hooks: List[Callable] = [] # Will check if parameters are finite at every iteration (can be costly) self._validate_parameters = False diff --git a/src/gflownet/utils/multiobjective_hooks.py b/src/gflownet/utils/multiobjective_hooks.py index 623af685..29874e8a 100644 --- a/src/gflownet/utils/multiobjective_hooks.py +++ 
b/src/gflownet/utils/multiobjective_hooks.py @@ -270,6 +270,9 @@ class TrajectoryLengthHook: Report the average trajectory length. """ + def __init__(self) -> None: + pass + def __call__(self, trajs, rewards, flat_rewards, cond_info): ret = {} ret["sample_len"] = sum([len(i["traj"]) for i in trajs]) / len(trajs) From 1bf872469e7346f7da103d9f5a831066bdeec8d5 Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Fri, 9 Feb 2024 02:01:35 -0500 Subject: [PATCH 12/33] add flag to store all checkpoints --- src/gflownet/config.py | 3 +++ src/gflownet/trainer.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/gflownet/config.py b/src/gflownet/config.py index be4fa879..6941e7a7 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -60,6 +60,8 @@ class Config: The number of training steps after which to validate the model checkpoint_every : Optional[int] The number of training steps after which to checkpoint the model + store_all_checkpoints : bool + Whether to store all checkpoints or only the last one print_every : int The number of training steps after which to print the training loss start_at_step : int @@ -85,6 +87,7 @@ class Config: seed: int = 0 validate_every: int = 1000 checkpoint_every: Optional[int] = None + store_all_checkpoints: bool = False print_every: int = 100 start_at_step: int = 0 num_final_gen_steps: Optional[int] = None diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index d0a10016..bd748e2e 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -418,7 +418,8 @@ def _save_state(self, it): state, fd, ) - shutil.copy(fn, pathlib.Path(self.cfg.log_dir) / f"model_state_{it}.pt") + if self.cfg.store_all_checkpoints: + shutil.copy(fn, pathlib.Path(self.cfg.log_dir) / f"model_state_{it}.pt") def log(self, info, index, key): if not hasattr(self, "_summary_writer"): From 2ea531faf9b1c9fe8da14198d6517491fea3ab72 Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Fri, 9 Feb 2024 02:04:14 -0500 Subject: [PATCH 13/33] fix moohook in seh_frag and remove it in qm9 --- src/gflownet/tasks/qm9_moo.py | 14 -------------- src/gflownet/tasks/seh_frag_moo.py | 1 + 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py index 2da477e6..57632063 100644 --- a/src/gflownet/tasks/qm9_moo.py +++ b/src/gflownet/tasks/qm9_moo.py @@ -259,20 +259,6 @@ def setup_task(self): def setup(self): super().setup() - # self.sampling_hooks.append( - # MultiObjectiveStatsHook( - # 256, - # self.cfg.log_dir, - # compute_igd=True, - # compute_hvi=False, - # compute_hsri=False, - # compute_normed=False, - # compute_pc_entropy=True, - # compute_focus_accuracy=True if self.cfg.task.qm9_moo.focus_type is not None else False, - # focus_cosim=self.cfg.task.qm9_moo.focus_cosim, - # ) - # ) - # self.to_close.append(self.sampling_hooks[-1].keep_alive) # instantiate preference and focus conditioning vectors for validation tcfg = self.cfg.task.qm9_moo diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 6da5b833..893a4b0c 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -271,6 +271,7 @@ def setup(self): focus_cosim=self.cfg.cond.focus_region.focus_cosim, ) ) + self.to_close.append(self.sampling_hooks[-1].keep_alive) # instantiate preference and focus conditioning vectors for validation n_obj = len(self.cfg.task.seh_moo.objectives) From 0760b87514316cababbc1e81a187830c43b4a3e9 Mon Sep 17 00:00:00 2001 From: 
Sobhan Mohammadpour Date: Fri, 9 Feb 2024 02:05:25 -0500 Subject: [PATCH 14/33] add comment about the graceful termination of moostats --- src/gflownet/utils/multiobjective_hooks.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/gflownet/utils/multiobjective_hooks.py b/src/gflownet/utils/multiobjective_hooks.py index 29874e8a..6dda97ee 100644 --- a/src/gflownet/utils/multiobjective_hooks.py +++ b/src/gflownet/utils/multiobjective_hooks.py @@ -15,6 +15,10 @@ class MultiObjectiveStatsHook: + """ + This hook is multithreaded and the keep_alive object needs to be closed for graceful termination. + """ + def __init__( self, num_to_keep: int, From 6dcce365d04951e75f4b44da408fbfed39048e37 Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Fri, 9 Feb 2024 23:54:12 -0500 Subject: [PATCH 15/33] add a flag for predicting n --- src/gflownet/algo/config.py | 3 +++ src/gflownet/online_trainer.py | 1 + 2 files changed, 4 insertions(+) diff --git a/src/gflownet/algo/config.py b/src/gflownet/algo/config.py index 26839ae3..6184bdfc 100644 --- a/src/gflownet/algo/config.py +++ b/src/gflownet/algo/config.py @@ -29,6 +29,8 @@ class TBConfig: Whether to correct for idempotent actions do_parameterize_p_b : bool Whether to parameterize the P_B distribution (otherwise it is uniform) + do_predict_n : bool + Whether to predict the number of paths in the graph do_length_normalize : bool Whether to normalize the loss by the length of the trajectory subtb_max_len : int @@ -45,6 +47,7 @@ class TBConfig: variant: TBVariant = TBVariant.TB do_correct_idempotent: bool = False do_parameterize_p_b: bool = False + do_predict_n: bool = False do_sample_p_b: bool = False do_length_normalize: bool = False subtb_max_len: int = 128 diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index f36b8211..81fd7ae7 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -31,6 +31,7 @@ def setup_model(self): self.ctx, self.cfg, do_bck=self.cfg.algo.tb.do_parameterize_p_b, + num_graph_out=self.cfg.algo.tb.do_predict_n + 1, ) def setup_algo(self): From d9acdb65a07822923ed340a2724557def8a82d0c Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Sat, 10 Feb 2024 23:42:33 -0500 Subject: [PATCH 16/33] REMOVE USELESS QM9 THING --- src/gflownet/tasks/qm9/qm9.py | 14 -------------- src/gflownet/tasks/qm9/qm9.yaml | 10 ---------- 2 files changed, 24 deletions(-) delete mode 100644 src/gflownet/tasks/qm9/qm9.yaml diff --git a/src/gflownet/tasks/qm9/qm9.py b/src/gflownet/tasks/qm9/qm9.py index a5d6ee2f..9541334d 100644 --- a/src/gflownet/tasks/qm9/qm9.py +++ b/src/gflownet/tasks/qm9/qm9.py @@ -5,7 +5,6 @@ import torch import torch.nn as nn import torch_geometric.data as gd -import yaml from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor from torch.utils.data import Dataset @@ -150,16 +149,3 @@ def setup(self): super().setup() self.training_data.setup(self.task, self.ctx) self.test_data.setup(self.task, self.ctx) - - -def main(): - """Example of how this model can be run.""" - config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "qm9.yaml") - with open(config_file, "r") as f: - hps = yaml.safe_load(f) - trial = QM9GapTrainer(hps) - trial.run() - - -if __name__ == "__main__": - main() diff --git a/src/gflownet/tasks/qm9/qm9.yaml b/src/gflownet/tasks/qm9/qm9.yaml deleted file mode 100644 index 19701fac..00000000 --- a/src/gflownet/tasks/qm9/qm9.yaml +++ /dev/null @@ -1,10 +0,0 @@ -opt: - lr_decay: 10000 -task: - qm9: - h5_path: 
/rxrx/data/chem/qm9/qm9.h5 - model_path: /rxrx/data/chem/qm9/mxmnet_gap_model.pt -num_training_steps: 100000 -validate_every: 100 -log_dir: ./logs/debug_qm9 -num_workers: 0 From 4f350ec44ad85d9533f3efe30ab3d5fa38e5d19e Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Sat, 10 Feb 2024 23:44:15 -0500 Subject: [PATCH 17/33] fix typo --- src/gflownet/tasks/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gflownet/tasks/config.py b/src/gflownet/tasks/config.py index c64cc4f4..1cbabad9 100644 --- a/src/gflownet/tasks/config.py +++ b/src/gflownet/tasks/config.py @@ -18,7 +18,7 @@ class SEHMOOTaskConfig: n_valid_repeats : int The number of times to repeat the valid cond_info tensors objectives : List[str] - The objectives to use for the multi-objective optimization. Should be a subset of ["seh", "qed", "sa", "wt"]. + The objectives to use for the multi-objective optimization. Should be a subset of ["seh", "qed", "sa", "mw"]. """ n_valid: int = 15 From 2a20d6a7a014b952a9ba0c7f9c6ae74d47f9ace4 Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Sun, 11 Feb 2024 00:04:09 -0500 Subject: [PATCH 18/33] upgrade qm9 --- src/gflownet/tasks/config.py | 22 ++-- src/gflownet/tasks/qm9_moo.py | 172 ++++++++++++++--------------- src/gflownet/tasks/seh_frag_moo.py | 58 +++++----- 3 files changed, 123 insertions(+), 129 deletions(-) diff --git a/src/gflownet/tasks/config.py b/src/gflownet/tasks/config.py index 1cbabad9..5d56f038 100644 --- a/src/gflownet/tasks/config.py +++ b/src/gflownet/tasks/config.py @@ -34,15 +34,19 @@ class QM9TaskConfig: @dataclass class QM9MOOTaskConfig: - use_steer_thermometer: bool = False - preference_type: Optional[str] = "dirichlet" - focus_type: Optional[str] = None - focus_dirs_listed: Optional[List[List[float]]] = None - focus_cosim: float = 0.0 - focus_limit_coef: float = 1.0 - focus_model_training_limits: Optional[Tuple[int, int]] = None - focus_model_state_space_res: Optional[int] = None - max_train_it: Optional[int] = None + """Config for the QM9MOOTask + + Attributes + ---------- + n_valid : int + The number of valid cond_info tensors to sample + n_valid_repeats : int + The number of times to repeat the valid cond_info tensors + objectives : List[str] + The objectives to use for the multi-objective optimization. Should be a subset of ["gap", "qed", "sa", "mw"]. + While "mw" can be used, it is not recommended as the molecules are already small. 
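+
+    For example, a run could restrict the task to two objectives with the
+    (hypothetical) settings:
+
+        cfg.task.qm9_moo.objectives = ["gap", "qed"]
+        cfg.task.qm9_moo.n_valid = 15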
+ """ + n_valid: int = 15 n_valid_repeats: int = 128 objectives: List[str] = field(default_factory=lambda: ["gap", "qed", "sa"]) diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py index 57632063..d2266c43 100644 --- a/src/gflownet/tasks/qm9_moo.py +++ b/src/gflownet/tasks/qm9_moo.py @@ -10,6 +10,8 @@ from torch import Tensor from torch.utils.data import Dataset +from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL +from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce import gflownet.models.mxmnet as mxmnet from gflownet.config import Config from gflownet.data.qm9 import QM9Dataset @@ -18,14 +20,9 @@ from gflownet.trainer import FlatRewards, RewardScalar from gflownet.utils import metrics, sascore from gflownet.utils.conditioning import FocusRegionConditional, MultiObjectiveWeightedPreferences -from gflownet.utils.multiobjective_hooks import TopKHook - - -def safe(f, x, default): - try: - return f(x) - except Exception: - return default +from gflownet.utils.multiobjective_hooks import MultiObjectiveStatsHook, TopKHook +from gflownet.utils.transforms import to_logreward +from gflownet.tasks.seh_frag_moo import RepeatedCondInfoDataset, safe class QM9GapMOOTask(QM9GapTask): @@ -65,50 +62,11 @@ def __init__( ) assert set(self.objectives) <= {"gap", "qed", "sa", "mw"} and len(self.objectives) == len(set(self.objectives)) - def flat_reward_transform(self, y: Tensor) -> FlatRewards: - assert y.shape[-1] == len(self.objectives) - if len(y.shape) == 1: - y = y[None, :] - assert len(y.shape) == 2 - - flat_r = [] - for i, obj in enumerate(self.objectives): - preds = y[:, i] - if obj == "gap": - preds = super().flat_reward_transform(preds) - elif obj == "qed": - pass - elif obj == "sa": - preds = (10 - preds) / 9 # Turn into a [0-1] reward - elif obj == "mw": - preds = ((300 - preds) / 700 + 1).clip(0, 1) # 1 until 300 then linear decay to 0 until 1000 - else: - raise ValueError(f"{obj} not known") - flat_r.append(preds) - return FlatRewards(torch.stack(flat_r, dim=1)) + def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: + return FlatRewards(torch.as_tensor(y)) def inverse_flat_reward_transform(self, rp): - assert rp.shape[-1] == len(self.objectives) - if len(rp.shape) == 1: - rp = rp[None, :] - assert len(rp.shape) == 2 - - flat_r = [] - for i, obj in enumerate(self.objectives): - preds = rp[:, i] - if obj == "qed": - preds = super().inverse_flat_reward_transform(preds) - elif obj == "qed": - pass - elif obj == "sa": - preds = 10 - 9 * preds - elif obj == "mw": - preds = 300 - 700 * (preds - 1) - else: - raise ValueError(f"{obj} not known") - flat_r.append(preds) - - return FlatRewards(torch.stack(flat_r, dim=1)) + return rp def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: cond_info = super().sample_conditional_information(n, train_it) @@ -189,20 +147,26 @@ def relabel_condinfo_and_logrewards( return cond_info, log_rewards def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: + """ + Compute the logreward from the flat_reward and the conditional information + """ if isinstance(flat_reward, list): if isinstance(flat_reward[0], Tensor): flat_reward = torch.stack(flat_reward) else: flat_reward = torch.tensor(flat_reward) - scalarized_reward = self.pref_cond.transform(cond_info, flat_reward) - focused_reward = ( - self.focus_cond.transform(cond_info, flat_reward, scalarized_reward) + scalarized_rewards = 
self.pref_cond.transform(cond_info, flat_reward) + scalarized_logrewards = to_logreward(scalarized_rewards) + focused_logreward = ( + self.focus_cond.transform(cond_info, flat_reward, scalarized_logrewards) if self.focus_cond is not None - else scalarized_reward + else scalarized_logrewards ) - tempered_reward = self.temperature_conditional.transform(cond_info, focused_reward) - return RewardScalar(tempered_reward) + tempered_logreward = self.temperature_conditional.transform(cond_info, focused_logreward) + clamped_logreward = tempered_logreward.clamp(min=self.cfg.algo.illegal_action_logreward) + + return RewardScalar(clamped_logreward) def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: graphs = [mxmnet.mol2graph(i) for i in mols] # type: ignore[attr-defined] @@ -222,19 +186,23 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: preds = self.models["mxmnet_gap"](batch) preds = preds.reshape((-1,)).data.cpu() / mxmnet.HAR2EV # type: ignore[attr-defined] preds[preds.isnan()] = 1 + preds = super().flat_reward_transform(preds) elif obj == "qed": preds = torch.tensor([safe(QED.qed, i, 0) for i in valid_mols]) elif obj == "sa": preds = torch.tensor([safe(sascore.calculateScore, i, 10) for i in valid_mols]) + preds = (10 - preds) / 9 # Turn into a [0-1] reward elif obj == "mw": preds = torch.tensor([safe(Descriptors.MolWt, i, 1000) for i in valid_mols]) + preds = ((300 - preds) / 700 + 1).clip(0, 1) # 1 until 300 then linear decay to 0 until 1000 else: - raise ValueError(f"{obj} not known") + raise ValueError(f"MOO objective {obj} not known") + assert len(preds) == len(valid_graphs), f"len of reward {obj} is {len(preds)} not the expected {len(valid_graphs)}" flat_r.append(preds) - assert len(preds) == len(valid_graphs), f"{len(preds)} != {len(valid_graphs)} for obj {obj}" + flat_rewards = torch.stack(flat_r, dim=1) - return self.flat_reward_transform(flat_rewards), is_valid + return FlatRewards(flat_rewards), is_valid class QM9MOOTrainer(QM9GapTrainer): @@ -249,6 +217,14 @@ def set_default_hps(self, cfg: Config): cfg.algo.valid_sample_cond_info = False cfg.algo.valid_offline_ratio = 1 + def setup_algo(self): + algo = self.cfg.algo.method + if algo == "MOREINFORCE": + self.algo = MultiObjectiveReinforce(self.env, self.ctx, self.rng, self.cfg) + elif algo == "MOQL": + self.algo = EnvelopeQLearning(self.env, self.ctx, self.task, self.rng, self.cfg) + else: + super().setup_algo() def setup_task(self): self.task = QM9GapMOOTask( dataset=self.training_data, @@ -257,42 +233,68 @@ def setup_task(self): wrap_model=self._wrap_for_mp, ) + def setup_env_context(self): + self.ctx = FragMolBuildingEnvContext(max_frags=self.cfg.algo.max_nodes, num_cond_dim=self.task.num_cond_dim) + + def setup_model(self): + if self.cfg.algo.method == "MOQL": + self.model = GraphTransformerFragEnvelopeQL( + self.ctx, + num_emb=self.cfg.model.num_emb, + num_layers=self.cfg.model.num_layers, + num_heads=self.cfg.model.graph_transformer.num_heads, + num_objectives=len(self.cfg.task.seh_moo.objectives), + ) + else: + super().setup_model() def setup(self): super().setup() + self.sampling_hooks.append( + MultiObjectiveStatsHook( + 256, + self.cfg.log_dir, + compute_igd=True, + compute_pc_entropy=True, + compute_focus_accuracy=True if self.cfg.cond.focus_region.focus_type is not None else False, + focus_cosim=self.cfg.cond.focus_region.focus_cosim, + ) + ) + self.to_close.append(self.sampling_hooks[-1].keep_alive) # instantiate preference and focus conditioning vectors for 
validation - tcfg = self.cfg.task.qm9_moo - n_obj = len(tcfg.objectives) + n_obj = len(self.cfg.task.seh_moo.objectives) + cond_cfg = self.cfg.cond # making sure hyperparameters for preferences and focus regions are consistent if not ( - tcfg.focus_type is None - or tcfg.focus_type == "centered" - or (type(tcfg.focus_type) is list and len(tcfg.focus_type) == 1) + cond_cfg.focus_region.focus_type is None + or cond_cfg.focus_region.focus_type == "centered" + or (isinstance(cond_cfg.focus_region.focus_type, list) and len(cond_cfg.focus_region.focus_type) == 1) ): - assert tcfg.preference_type is None, ( - f"Cannot use preferences with multiple focus regions, here focus_type={tcfg.focus_type} " - f"and preference_type={tcfg.preference_type}" + assert cond_cfg.weighted_prefs.preference_type is None, ( + f"Cannot use preferences with multiple focus regions, " + f"here focus_type={cond_cfg.focus_region.focus_type} " + f"and preference_type={cond_cfg.weighted_prefs.preference_type }" ) - if type(tcfg.focus_type) is list and len(tcfg.focus_type) > 1: - n_valid = len(tcfg.focus_type) + if isinstance(cond_cfg.focus_region.focus_type, list) and len(cond_cfg.focus_region.focus_type) > 1: + n_valid = len(cond_cfg.focus_region.focus_type) else: - n_valid = tcfg.n_valid + n_valid = self.cfg.task.seh_moo.n_valid # preference vectors - if tcfg.preference_type is None: + if cond_cfg.weighted_prefs.preference_type is None: valid_preferences = np.ones((n_valid, n_obj)) - elif tcfg.preference_type == "dirichlet": + elif cond_cfg.weighted_prefs.preference_type == "dirichlet": valid_preferences = metrics.partition_hypersphere(d=n_obj, k=n_valid, normalisation="l1") - elif tcfg.preference_type == "seeded_single": + elif cond_cfg.weighted_prefs.preference_type == "seeded_single": seeded_prefs = np.random.default_rng(142857 + int(self.cfg.seed)).dirichlet([1] * n_obj, n_valid) valid_preferences = seeded_prefs[0].reshape((1, n_obj)) self.task.seeded_preference = valid_preferences[0] - elif tcfg.preference_type == "seeded_many": + elif cond_cfg.weighted_prefs.preference_type == "seeded_many": valid_preferences = np.random.default_rng(142857 + int(self.cfg.seed)).dirichlet([1] * n_obj, n_valid) else: - raise NotImplementedError(f"Unknown preference type {self.cfg.task.qm9_moo.preference_type}") + raise NotImplementedError(f"Unknown preference type {cond_cfg.weighted_prefs.preference_type}") # TODO: this was previously reported, would be nice to serialize it # hps["fixed_focus_dirs"] = ( @@ -313,8 +315,8 @@ def setup(self): else: valid_cond_vector = valid_preferences - self._top_k_hook = TopKHook(10, tcfg.n_valid_repeats, n_valid) - self.test_data = RepeatedCondInfoDataset(valid_cond_vector, repeat=tcfg.n_valid_repeats) + self._top_k_hook = TopKHook(10, self.cfg.task.seh_moo.n_valid_repeats, n_valid) + self.test_data = RepeatedCondInfoDataset(valid_cond_vector, repeat=self.cfg.task.seh_moo.n_valid_repeats) self.valid_sampling_hooks.append(self._top_k_hook) self.algo.task = self.task @@ -347,15 +349,9 @@ def _save_state(self, it): self.task.focus_cond.focus_model.save(pathlib.Path(self.cfg.log_dir)) return super()._save_state(it) + def run(self): + super().run() + for hook in self.sampling_hooks: + if hasattr(hook, "terminate"): + hook.terminate() -class RepeatedCondInfoDataset: - def __init__(self, cond_info_vectors, repeat): - self.cond_info_vectors = cond_info_vectors - self.repeat = repeat - - def __len__(self): - return len(self.cond_info_vectors) * self.repeat - - def __getitem__(self, idx): - assert 0 <= idx < 
len(self) - return torch.tensor(self.cond_info_vectors[int(idx // self.repeat)]) diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 893a4b0c..fc5f0db0 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -24,6 +24,11 @@ from gflownet.utils.multiobjective_hooks import MultiObjectiveStatsHook, TopKHook from gflownet.utils.transforms import to_logreward +def safe(f, x, default): + try: + return f(x) + except Exception: + return default class SEHMOOTask(SEHTask): """Sets up a multiobjective task where the rewards are (functions of): @@ -180,37 +185,26 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: else: flat_r: List[Tensor] = [] - if "seh" in self.objectives: - batch = gd.Batch.from_data_list(valid_graphs) - batch.to(self.device) - seh_preds = self.models["seh"](batch).reshape((-1,)).clip(1e-4, 100).data.cpu() / 8 - seh_preds[seh_preds.isnan()] = 0 - flat_r.append(seh_preds) - assert len(seh_preds) == len(valid_graphs), f"{len(seh_preds)} != {len(valid_graphs)}" - - def safe(f, x, default): - try: - return f(x) - except Exception: - return default - - if "qed" in self.objectives: - qeds = torch.tensor([safe(QED.qed, i, 0) for i in valid_mols]) - flat_r.append(qeds) - assert len(qeds) == len(valid_graphs), f"{len(qeds)} != {len(valid_graphs)}" - - if "sa" in self.objectives: - sas = torch.tensor([safe(sascore.calculateScore, i, 10) for i in valid_mols]) - sas = (10 - sas) / 9 # Turn into a [0-1] reward - flat_r.append(sas) - assert len(sas) == len(valid_graphs), f"{len(sas)} != {len(valid_graphs)}" - - if "mw" in self.objectives: - molwts = torch.tensor([safe(Descriptors.MolWt, i, 1000) for i in valid_mols]) - molwts = ((300 - molwts) / 700 + 1).clip(0, 1) # 1 until 300 then linear decay to 0 until 1000 - flat_r.append(molwts) - assert len(molwts) == len(valid_graphs), f"{len(molwts)} != {len(valid_graphs)}" - + for obj in self.objectives: + if obj == "seh": + batch = gd.Batch.from_data_list(valid_graphs) + batch.to(self.device) + preds = self.models["seh"](batch).reshape((-1,)).clip(1e-4, 100).data.cpu() + preds[preds.isnan()] = 0 + preds = super().flat_reward_transform(preds) + elif obj == "qed": + preds = torch.tensor([safe(QED.qed, i, 0) for i in valid_mols]) + elif obj == "sa": + preds = torch.tensor([safe(sascore.calculateScore, i, 10) for i in valid_mols]) + preds = (10 - preds) / 9 # Turn into a [0-1] reward + elif obj == "mw": + preds = torch.tensor([safe(Descriptors.MolWt, i, 1000) for i in valid_mols]) + preds = ((300 - preds) / 700 + 1).clip(0, 1) # 1 until 300 then linear decay to 0 until 1000 + else: + raise ValueError(f"MOO objective {obj} not known") + assert len(preds) == len(valid_graphs), f"len of reward {obj} is {len(preds)} not the expected {len(valid_graphs)}" + flat_r.append(preds) + flat_rewards = torch.stack(flat_r, dim=1) return FlatRewards(flat_rewards), is_valid @@ -306,7 +300,7 @@ def setup(self): elif cond_cfg.weighted_prefs.preference_type == "seeded_many": valid_preferences = np.random.default_rng(142857 + int(self.cfg.seed)).dirichlet([1] * n_obj, n_valid) else: - raise NotImplementedError(f"Unknown preference type {self.cfg.task.seh_moo.preference_type}") + raise NotImplementedError(f"Unknown preference type {cond_cfg.weighted_prefs.preference_type}") # TODO: this was previously reported, would be nice to serialize it # hps["fixed_focus_dirs"] = ( From 1ac2cd035435e3de55c769f6c85205f071e11dac Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour 
Date: Sun, 11 Feb 2024 00:06:20 -0500 Subject: [PATCH 19/33] fmt --- src/gflownet/tasks/qm9_moo.py | 14 ++++++++------ src/gflownet/tasks/seh_frag_moo.py | 8 ++++++-- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py index d2266c43..bd171adf 100644 --- a/src/gflownet/tasks/qm9_moo.py +++ b/src/gflownet/tasks/qm9_moo.py @@ -1,5 +1,5 @@ import pathlib -from typing import Any, Callable, Dict, List, Tuple +from typing import Any, Callable, Dict, List, Tuple, Union import numpy as np import torch @@ -10,19 +10,19 @@ from torch import Tensor from torch.utils.data import Dataset +import gflownet.models.mxmnet as mxmnet from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce -import gflownet.models.mxmnet as mxmnet from gflownet.config import Config from gflownet.data.qm9 import QM9Dataset from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.tasks.qm9.qm9 import QM9GapTask, QM9GapTrainer +from gflownet.tasks.seh_frag_moo import RepeatedCondInfoDataset, safe from gflownet.trainer import FlatRewards, RewardScalar from gflownet.utils import metrics, sascore from gflownet.utils.conditioning import FocusRegionConditional, MultiObjectiveWeightedPreferences from gflownet.utils.multiobjective_hooks import MultiObjectiveStatsHook, TopKHook from gflownet.utils.transforms import to_logreward -from gflownet.tasks.seh_frag_moo import RepeatedCondInfoDataset, safe class QM9GapMOOTask(QM9GapTask): @@ -197,9 +197,10 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: preds = ((300 - preds) / 700 + 1).clip(0, 1) # 1 until 300 then linear decay to 0 until 1000 else: raise ValueError(f"MOO objective {obj} not known") - assert len(preds) == len(valid_graphs), f"len of reward {obj} is {len(preds)} not the expected {len(valid_graphs)}" + assert len(preds) == len( + valid_graphs + ), f"len of reward {obj} is {len(preds)} not the expected {len(valid_graphs)}" flat_r.append(preds) - flat_rewards = torch.stack(flat_r, dim=1) return FlatRewards(flat_rewards), is_valid @@ -225,6 +226,7 @@ def setup_algo(self): self.algo = EnvelopeQLearning(self.env, self.ctx, self.task, self.rng, self.cfg) else: super().setup_algo() + def setup_task(self): self.task = QM9GapMOOTask( dataset=self.training_data, @@ -247,6 +249,7 @@ def setup_model(self): ) else: super().setup_model() + def setup(self): super().setup() self.sampling_hooks.append( @@ -354,4 +357,3 @@ def run(self): for hook in self.sampling_hooks: if hasattr(hook, "terminate"): hook.terminate() - diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index fc5f0db0..63ae38bd 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -24,12 +24,14 @@ from gflownet.utils.multiobjective_hooks import MultiObjectiveStatsHook, TopKHook from gflownet.utils.transforms import to_logreward + def safe(f, x, default): try: return f(x) except Exception: return default + class SEHMOOTask(SEHTask): """Sets up a multiobjective task where the rewards are (functions of): - the binding energy of a molecule to Soluble Epoxide Hydrolases, @@ -202,9 +204,11 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: preds = ((300 - preds) / 700 + 1).clip(0, 1) # 1 until 300 then linear decay to 0 until 1000 else: raise ValueError(f"MOO objective {obj} not known") - assert 
len(preds) == len(valid_graphs), f"len of reward {obj} is {len(preds)} not the expected {len(valid_graphs)}" + assert len(preds) == len( + valid_graphs + ), f"len of reward {obj} is {len(preds)} not the expected {len(valid_graphs)}" flat_r.append(preds) - + flat_rewards = torch.stack(flat_r, dim=1) return FlatRewards(flat_rewards), is_valid From dff6e147863735e356ea1209450fd14991554bdb Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Sun, 11 Feb 2024 00:08:53 -0500 Subject: [PATCH 20/33] fmt --- src/gflownet/tasks/config.py | 2 +- src/gflownet/tasks/qm9_moo.py | 3 --- src/gflownet/tasks/seh_frag_moo.py | 3 --- 3 files changed, 1 insertion(+), 7 deletions(-) diff --git a/src/gflownet/tasks/config.py b/src/gflownet/tasks/config.py index 5d56f038..ed0d0a7e 100644 --- a/src/gflownet/tasks/config.py +++ b/src/gflownet/tasks/config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import List, Optional, Tuple +from typing import List @dataclass diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py index bd171adf..cf18c376 100644 --- a/src/gflownet/tasks/qm9_moo.py +++ b/src/gflownet/tasks/qm9_moo.py @@ -235,9 +235,6 @@ def setup_task(self): wrap_model=self._wrap_for_mp, ) - def setup_env_context(self): - self.ctx = FragMolBuildingEnvContext(max_frags=self.cfg.algo.max_nodes, num_cond_dim=self.task.num_cond_dim) - def setup_model(self): if self.cfg.algo.method == "MOQL": self.model = GraphTransformerFragEnvelopeQL( diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 63ae38bd..519f0dd3 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -242,9 +242,6 @@ def setup_task(self): wrap_model=self._wrap_for_mp, ) - def setup_env_context(self): - self.ctx = FragMolBuildingEnvContext(max_frags=self.cfg.algo.max_nodes, num_cond_dim=self.task.num_cond_dim) - def setup_model(self): if self.cfg.algo.method == "MOQL": self.model = GraphTransformerFragEnvelopeQL( From d7ae387372116f8e5dcb4e077427160efc5dc712 Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Sun, 11 Feb 2024 01:13:58 -0500 Subject: [PATCH 21/33] broadcast back the invalid results --- src/gflownet/tasks/seh_frag.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 0ce1e385..e4af61e4 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -11,6 +11,7 @@ from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor from torch.utils.data import Dataset +from torch_geometric.data import Data from gflownet.config import Config from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext, Graph @@ -72,7 +73,10 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: preds = self.models["seh"](batch).reshape((-1,)).data.cpu() preds[preds.isnan()] = 0 preds = self.flat_reward_transform(preds).clip(1e-4, 100).reshape((-1, 1)) - return FlatRewards(preds), is_valid + preds_full = torch.zeros(len(is_valid), 1) + preds_full[is_valid] = preds + assert preds_full.shape == (len(is_valid), 1) + return FlatRewards(preds_full), is_valid SOME_MOLS = [ From 994cbcb675f598a8470d27130f0ced83c257011b Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Sun, 11 Feb 2024 01:23:03 -0500 Subject: [PATCH 22/33] add compute_reward_from_graph method to seh --- src/gflownet/tasks/seh_frag_moo.py | 55 ++++++++++++++++-------------- 1 file changed, 29 insertions(+), 26 deletions(-) diff 
--git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 519f0dd3..9e279fe2 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -32,6 +32,27 @@ def safe(f, x, default): return default +def mol2mw(mols: list[RDMol], is_valid: list[bool], default=1000): + molwts = torch.tensor([safe(Descriptors.MolWt, i, default) if v else default for i, v in zip(mols, is_valid)]) + molwts = ((300 - molwts) / 700 + 1).clip(0, 1) # 1 until 300 then linear decay to 0 until 1000 + return molwts + + +def mol2sas(mols: list[RDMol], is_valid: list[bool], default=10): + sas = torch.tensor( + [safe(sascore.calculateScore, i, default) if is_valid else default for i, v in zip(mols, is_valid)] + ) + sas = (10 - sas) / 9 # Turn into a [0-1] reward + return sas + + +def mol2qed(mols: list[RDMol], is_valid: list[bool], default=0): + return torch.tensor([safe(QED.qed, i, 0) if v else default for i, v in zip(mols, is_valid)]) + + +aux_tasks = {"qed": mol2qed, "sa": mol2sas, "mw": mol2mw} + + class SEHMOOTask(SEHTask): """Sets up a multiobjective task where the rewards are (functions of): - the binding energy of a molecule to Soluble Epoxide Hydrolases, @@ -178,39 +199,21 @@ def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: Flat def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: graphs = [bengio2021flow.mol2graph(i) for i in mols] assert len(graphs) == len(mols) - is_valid = torch.tensor([i is not None for i in graphs]).bool() - valid_graphs = [g for g in graphs if g is not None] - valid_mols = [m for m, g in zip(mols, graphs) if g is not None] - assert len(valid_mols) == len(valid_graphs) - if not is_valid.any(): - return FlatRewards(torch.zeros((0, len(self.objectives)))), is_valid - + is_valid = [i is not None for i in graphs] + is_valid_t = torch.tensor(is_valid, dtype=torch.bool) + if not any(is_valid): + return FlatRewards(torch.zeros((0, len(self.objectives)))), is_valid_t else: flat_r: List[Tensor] = [] for obj in self.objectives: if obj == "seh": - batch = gd.Batch.from_data_list(valid_graphs) - batch.to(self.device) - preds = self.models["seh"](batch).reshape((-1,)).clip(1e-4, 100).data.cpu() - preds[preds.isnan()] = 0 - preds = super().flat_reward_transform(preds) - elif obj == "qed": - preds = torch.tensor([safe(QED.qed, i, 0) for i in valid_mols]) - elif obj == "sa": - preds = torch.tensor([safe(sascore.calculateScore, i, 10) for i in valid_mols]) - preds = (10 - preds) / 9 # Turn into a [0-1] reward - elif obj == "mw": - preds = torch.tensor([safe(Descriptors.MolWt, i, 1000) for i in valid_mols]) - preds = ((300 - preds) / 700 + 1).clip(0, 1) # 1 until 300 then linear decay to 0 until 1000 + flat_r.append(super().compute_reward_from_graph(graphs, is_valid_t)) else: - raise ValueError(f"MOO objective {obj} not known") - assert len(preds) == len( - valid_graphs - ), f"len of reward {obj} is {len(preds)} not the expected {len(valid_graphs)}" - flat_r.append(preds) + flat_r.append(aux_tasks[obj](mols, is_valid)) flat_rewards = torch.stack(flat_r, dim=1) - return FlatRewards(flat_rewards), is_valid + assert flat_rewards.shape[0] == len(mols) + return FlatRewards(flat_rewards), is_valid_t class SEHMOOFragTrainer(SEHFragTrainer): From cd2893adbbcbb034b3ff81153390ad7873332bbf Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Sun, 11 Feb 2024 01:23:31 -0500 Subject: [PATCH 23/33] use compute_reward_from_graph in seh_moo --- src/gflownet/tasks/seh_frag.py | 31 +++++++++++++++++++++---------- 
 1 file changed, 21 insertions(+), 10 deletions(-)

diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py
index e4af61e4..f13979e4 100644
--- a/src/gflownet/tasks/seh_frag.py
+++ b/src/gflownet/tasks/seh_frag.py
@@ -1,7 +1,7 @@
 import os
 import shutil
 import socket
-from typing import Callable, Dict, List, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -63,20 +63,31 @@ def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Ten
     def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar:
         return RewardScalar(self.temperature_conditional.transform(cond_info, to_logreward(flat_reward)))

+    def compute_reward_from_graph(self, graphs: List[Data], is_valid: Optional[Tensor]) -> Tensor:
+        batch = gd.Batch.from_data_list([i for i in graphs if i is not None])
+        if is_valid is None:
+            is_valid = torch.tensor([i is not None for i in graphs], dtype=torch.bool)
+        batch.to(self.device)
+        preds = self.models["seh"](batch).reshape((-1,)).data.cpu()
+        preds[preds.isnan()] = 0
+        preds = self.flat_reward_transform(preds).clip(1e-4, 100).reshape((-1,))
+        if is_valid is not None:
+            assert len(is_valid) >= len(preds)
+            preds_full = torch.zeros(len(is_valid), 1)
+            preds_full[is_valid] = preds
+            return preds_full
+        else:
+            return preds
+
     def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]:
         graphs = [bengio2021flow.mol2graph(i) for i in mols]
         is_valid = torch.tensor([i is not None for i in graphs]).bool()
         if not is_valid.any():
             return FlatRewards(torch.zeros((0, 1))), is_valid
-        batch = gd.Batch.from_data_list([i for i in graphs if i is not None])
-        batch.to(self.device)
-        preds = self.models["seh"](batch).reshape((-1,)).data.cpu()
-        preds[preds.isnan()] = 0
-        preds = self.flat_reward_transform(preds).clip(1e-4, 100).reshape((-1, 1))
-        preds_full = torch.zeros(len(is_valid), 1)
-        preds_full[is_valid] = preds
-        assert preds_full.shape == (len(is_valid), 1)
-        return FlatRewards(preds_full), is_valid
+
+        preds = self.compute_reward_from_graph(graphs, is_valid).reshape((-1, 1))
+        assert len(preds) == len(mols)
+        return FlatRewards(preds), is_valid


 SOME_MOLS = [

From b8930e0e58b23204224fce0f2b36e7b06cd47c17 Mon Sep 17 00:00:00 2001
From: Sobhan Mohammadpour
Date: Fri, 16 Feb 2024 14:01:24 -0500
Subject: [PATCH 24/33] unify terminate and to_close

---
 src/gflownet/tasks/config.py | 34 +++++++++++++--------
 src/gflownet/tasks/qm9_moo.py | 14 ++++++++-
 src/gflownet/tasks/seh_frag_moo.py | 27 ++++++++--------
 src/gflownet/trainer.py | 10 +++---
 src/gflownet/utils/multiobjective_hooks.py | 2 --
 src/gflownet/utils/multiprocessing_proxy.py | 17 ++---------
 6 files changed, 57 insertions(+), 47 deletions(-)

diff --git a/src/gflownet/tasks/config.py b/src/gflownet/tasks/config.py
index c64cc4f4..2dc66b53 100644
--- a/src/gflownet/tasks/config.py
+++ b/src/gflownet/tasks/config.py
@@ -14,17 +14,19 @@ class SEHMOOTaskConfig:
     Attributes
     ----------
     n_valid : int
-        The number of valid cond_info tensors to sample
+        The number of valid cond_info tensors to sample.
     n_valid_repeats : int
-        The number of times to repeat the valid cond_info tensors
+        The number of times to repeat the valid cond_info tensors.
     objectives : List[str]
-        The objectives to use for the multi-objective optimization. Should be a subset of ["seh", "qed", "sa", "wt"].
+        The objectives to use for the multi-objective optimization. Should be a subset of ["seh", "qed", "sa", "mw"].
+ online_pareto_front : bool + Whether to calculate the pareto front online. """ n_valid: int = 15 n_valid_repeats: int = 128 objectives: List[str] = field(default_factory=lambda: ["seh", "qed", "sa", "mw"]) - + online_pareto_front: bool = True @dataclass class QM9TaskConfig: @@ -34,18 +36,24 @@ class QM9TaskConfig: @dataclass class QM9MOOTaskConfig: - use_steer_thermometer: bool = False - preference_type: Optional[str] = "dirichlet" - focus_type: Optional[str] = None - focus_dirs_listed: Optional[List[List[float]]] = None - focus_cosim: float = 0.0 - focus_limit_coef: float = 1.0 - focus_model_training_limits: Optional[Tuple[int, int]] = None - focus_model_state_space_res: Optional[int] = None - max_train_it: Optional[int] = None + """ + Config for the QM9MooTask + + Attributes + ---------- + n_valid : int + The number of valid cond_info tensors to sample. + n_valid_repeats : int + The number of times to repeat the valid cond_info tensors. + objectives : List[str] + The objectives to use for the multi-objective optimization. Should be a subset of ["gap", "qed", "sa"]. + online_pareto_front : bool + Whether to calculate the pareto front online. + """ n_valid: int = 15 n_valid_repeats: int = 128 objectives: List[str] = field(default_factory=lambda: ["gap", "qed", "sa"]) + online_pareto_front: bool = True @dataclass diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py index 57632063..7df45096 100644 --- a/src/gflownet/tasks/qm9_moo.py +++ b/src/gflownet/tasks/qm9_moo.py @@ -18,7 +18,7 @@ from gflownet.trainer import FlatRewards, RewardScalar from gflownet.utils import metrics, sascore from gflownet.utils.conditioning import FocusRegionConditional, MultiObjectiveWeightedPreferences -from gflownet.utils.multiobjective_hooks import TopKHook +from gflownet.utils.multiobjective_hooks import MultiObjectiveStatsHook, TopKHook def safe(f, x, default): @@ -259,6 +259,18 @@ def setup_task(self): def setup(self): super().setup() + if self.cfg.task.seh_moo.online_pareto_front: + self.sampling_hooks.append( + MultiObjectiveStatsHook( + 256, + self.cfg.log_dir, + compute_igd=True, + compute_pc_entropy=True, + compute_focus_accuracy=True if self.cfg.cond.focus_region.focus_type is not None else False, + focus_cosim=self.cfg.cond.focus_region.focus_cosim, + ) + ) + self.to_terminate.append(self.sampling_hooks[-1].terminate) # instantiate preference and focus conditioning vectors for validation tcfg = self.cfg.task.qm9_moo diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 893a4b0c..96d697c7 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -261,17 +261,18 @@ def setup_model(self): def setup(self): super().setup() - self.sampling_hooks.append( - MultiObjectiveStatsHook( - 256, - self.cfg.log_dir, - compute_igd=True, - compute_pc_entropy=True, - compute_focus_accuracy=True if self.cfg.cond.focus_region.focus_type is not None else False, - focus_cosim=self.cfg.cond.focus_region.focus_cosim, + if self.cfg.task.seh_moo.online_pareto_front: + self.sampling_hooks.append( + MultiObjectiveStatsHook( + 256, + self.cfg.log_dir, + compute_igd=True, + compute_pc_entropy=True, + compute_focus_accuracy=True if self.cfg.cond.focus_region.focus_type is not None else False, + focus_cosim=self.cfg.cond.focus_region.focus_cosim, + ) ) - ) - self.to_close.append(self.sampling_hooks[-1].keep_alive) + self.to_terminate.append(self.sampling_hooks[-1].terminate) # instantiate preference and focus conditioning vectors for 
validation n_obj = len(self.cfg.task.seh_moo.objectives) @@ -360,9 +361,11 @@ def _save_state(self, it): def run(self): super().run() for hook in self.sampling_hooks: - if hasattr(hook, "terminate"): + if hasattr(hook, "terminate") and hook.terminate not in self.to_terminate: hook.terminate() - + + ;for terminate in self.to_terminate: + terminate() class RepeatedCondInfoDataset: def __init__(self, cond_info_vectors, repeat): diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index bd748e2e..c77729c6 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -115,7 +115,7 @@ def __init__(self, hps: Dict[str, Any], print_hps=True): The torch device of the main worker. """ self.print_hps = print_hps - self.to_close: List[Closable] = [] + self.to_terminate: List[Closable] = [] # self.setup should at least set these up: self.training_data: Dataset self.test_data: Dataset @@ -188,13 +188,13 @@ def _wrap_for_mp(self, obj, send_to_device=False): if send_to_device: obj.to(self.device) if self.cfg.num_workers > 0 and obj is not None: - placeholder, keepalive = mp_object_wrapper( + placeholder = mp_object_wrapper( obj, self.cfg.num_workers, cast_types=(gd.Batch, GraphActionCategorical, SeqBatch), pickle_messages=self.cfg.pickle_mp_messages, ) - self.to_close.append(keepalive) + self.to_terminate.append(placeholder.terminate) return placeholder, torch.device("cpu") else: return obj, self.device @@ -428,9 +428,9 @@ def log(self, info, index, key): self._summary_writer.add_scalar(f"{key}_{k}", v, index) def close(self): - while len(self.to_close) > 0: + while len(self.to_terminate) > 0: try: - i = self.to_close.pop() + i = self.to_terminate.pop() i.close() except Exception as e: print(e) diff --git a/src/gflownet/utils/multiobjective_hooks.py b/src/gflownet/utils/multiobjective_hooks.py index 6dda97ee..115bef3a 100644 --- a/src/gflownet/utils/multiobjective_hooks.py +++ b/src/gflownet/utils/multiobjective_hooks.py @@ -11,7 +11,6 @@ from torch import Tensor from gflownet.utils import metrics -from gflownet.utils.multiprocessing_proxy import KeepAlive class MultiObjectiveStatsHook: @@ -60,7 +59,6 @@ def __init__( self.log_path = pathlib.Path(log_dir) / "pareto.pt" self.pareto_thread = threading.Thread(target=self._run_pareto_accumulation, daemon=True) self.pareto_thread.start() - self.keep_alive = KeepAlive(self.stop) def _hsri(self, x): assert x.ndim == 2, "x should have shape (num points, num objectives)" diff --git a/src/gflownet/utils/multiprocessing_proxy.py b/src/gflownet/utils/multiprocessing_proxy.py index d9be2558..6d26f3cf 100644 --- a/src/gflownet/utils/multiprocessing_proxy.py +++ b/src/gflownet/utils/multiprocessing_proxy.py @@ -62,16 +62,6 @@ def __len__(self): return self.out_queue.get() -class KeepAlive: - def __init__(self, flag): - self.flag = flag - - def close(self): - self.flag.set() - - def __del__(self): - self.close() - class MPObjectProxy: """This class maintains a reference to some object and @@ -114,7 +104,6 @@ def __init__(self, obj, num_workers: int, cast_types: tuple, pickle_messages: bo self.device = torch.device("cpu") self.cuda_types = (torch.Tensor,) + cast_types self.stop = threading.Event() - self.keepalive = KeepAlive(self.stop) self.thread = threading.Thread(target=self.run, daemon=True) self.thread.start() @@ -166,7 +155,8 @@ def run(self): else: msg = self.to_cpu(result) self.out_queues[qi].put(self.encode(msg)) - + def terminate(self): + self.stop.set() def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = False): 
"""Construct a multiprocessing object proxy for torch DataLoaders so @@ -203,5 +193,4 @@ def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = Fals A placeholder object whose method calls route arguments to the main process """ - x = MPObjectProxy(obj, num_workers, cast_types, pickle_messages) - return x.placeholder, x.keepalive + return MPObjectProxy(obj, num_workers, cast_types, pickle_messages) From 673a8801f95495b24891fb7236099bbef715def6 Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Fri, 16 Feb 2024 15:28:44 -0500 Subject: [PATCH 25/33] fmt --- src/gflownet/tasks/config.py | 2 + src/gflownet/tasks/qm9/qm9.py | 24 +++++++++--- src/gflownet/tasks/qm9_moo.py | 42 +++++---------------- src/gflownet/tasks/seh_frag_moo.py | 8 ---- src/gflownet/trainer.py | 20 +++++----- src/gflownet/utils/multiprocessing_proxy.py | 3 +- 6 files changed, 41 insertions(+), 58 deletions(-) diff --git a/src/gflownet/tasks/config.py b/src/gflownet/tasks/config.py index 40634a8f..7e7df30d 100644 --- a/src/gflownet/tasks/config.py +++ b/src/gflownet/tasks/config.py @@ -28,6 +28,7 @@ class SEHMOOTaskConfig: objectives: List[str] = field(default_factory=lambda: ["seh", "qed", "sa", "mw"]) online_pareto_front: bool = True + @dataclass class QM9TaskConfig: h5_path: str = "./data/qm9/qm9.h5" # see src/gflownet/data/qm9.py @@ -51,6 +52,7 @@ class QM9MOOTaskConfig: online_pareto_front : bool Whether to calculate the pareto front online. """ + n_valid: int = 15 n_valid_repeats: int = 128 objectives: List[str] = field(default_factory=lambda: ["gap", "qed", "sa"]) diff --git a/src/gflownet/tasks/qm9/qm9.py b/src/gflownet/tasks/qm9/qm9.py index 9541334d..0cf6d7b2 100644 --- a/src/gflownet/tasks/qm9/qm9.py +++ b/src/gflownet/tasks/qm9/qm9.py @@ -1,5 +1,5 @@ import os -from typing import Callable, Dict, List, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch @@ -84,16 +84,28 @@ def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Ten def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: return RewardScalar(self.temperature_conditional.transform(cond_info, to_logreward(flat_reward))) - def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: - graphs = [mxmnet.mol2graph(i) for i in mols] # type: ignore[attr-defined] - is_valid = torch.tensor([i is not None for i in graphs]).bool() - if not is_valid.any(): - return FlatRewards(torch.zeros((0, 1))), is_valid + def compute_reward_from_graph(self, graphs: List[gd.Data], is_valid: Optional[Tensor]) -> Tensor: batch = gd.Batch.from_data_list([i for i in graphs if i is not None]) batch.to(self.device) preds = self.models["mxmnet_gap"](batch).reshape((-1,)).data.cpu() / mxmnet.HAR2EV # type: ignore[attr-defined] preds[preds.isnan()] = 1 preds = self.flat_reward_transform(preds).clip(1e-4, 2).reshape((-1, 1)) + if is_valid is not None: + assert len(is_valid) >= len(preds) + preds_full = torch.zeros(len(is_valid), 1) + preds_full[is_valid] = preds + return preds_full + else: + return preds + + def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: + graphs = [mxmnet.mol2graph(i) for i in mols] # type: ignore[attr-defined] + is_valid = torch.tensor([i is not None for i in graphs]).bool() + if not is_valid.any(): + return FlatRewards(torch.zeros((0, 1))), is_valid + + preds = self.compute_reward_from_graph(graphs, is_valid).reshape((-1, 1)) + assert len(preds) == len(mols) 
return FlatRewards(preds), is_valid diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py index 4734d6c0..a963f75f 100644 --- a/src/gflownet/tasks/qm9_moo.py +++ b/src/gflownet/tasks/qm9_moo.py @@ -17,7 +17,7 @@ from gflownet.data.qm9 import QM9Dataset from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.tasks.qm9.qm9 import QM9GapTask, QM9GapTrainer -from gflownet.tasks.seh_frag_moo import RepeatedCondInfoDataset, safe +from gflownet.tasks.seh_frag_moo import RepeatedCondInfoDataset, aux_tasks, safe from gflownet.trainer import FlatRewards, RewardScalar from gflownet.utils import metrics, sascore from gflownet.utils.conditioning import FocusRegionConditional, MultiObjectiveWeightedPreferences @@ -171,39 +171,21 @@ def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: Flat def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: graphs = [mxmnet.mol2graph(i) for i in mols] # type: ignore[attr-defined] assert len(graphs) == len(mols) - is_valid = torch.tensor([i is not None for i in graphs]).bool() - valid_graphs = [g for g in graphs if g is not None] - valid_mols = [m for m, g in zip(mols, graphs) if g is not None] - assert len(valid_mols) == len(valid_graphs) - if not is_valid.any(): - return FlatRewards(torch.zeros((0, len(self.objectives)))), is_valid + is_valid = [i is not None for i in graphs] + is_valid_t = torch.tensor(is_valid, dtype=torch.bool) + if not any(is_valid): + return FlatRewards(torch.zeros((0, len(self.objectives)))), is_valid_t else: flat_r: List[Tensor] = [] for obj in self.objectives: if obj == "gap": - batch = gd.Batch.from_data_list(valid_graphs) - batch.to(self.device) - preds = self.models["mxmnet_gap"](batch) - preds = preds.reshape((-1,)).data.cpu() / mxmnet.HAR2EV # type: ignore[attr-defined] - preds[preds.isnan()] = 1 - preds = super().flat_reward_transform(preds) - elif obj == "qed": - preds = torch.tensor([safe(QED.qed, i, 0) for i in valid_mols]) - elif obj == "sa": - preds = torch.tensor([safe(sascore.calculateScore, i, 10) for i in valid_mols]) - preds = (10 - preds) / 9 # Turn into a [0-1] reward - elif obj == "mw": - preds = torch.tensor([safe(Descriptors.MolWt, i, 1000) for i in valid_mols]) - preds = ((300 - preds) / 700 + 1).clip(0, 1) # 1 until 300 then linear decay to 0 until 1000 + flat_r.append(super().compute_reward_from_graph(graphs, is_valid_t)) else: - raise ValueError(f"MOO objective {obj} not known") - assert len(preds) == len( - valid_graphs - ), f"len of reward {obj} is {len(preds)} not the expected {len(valid_graphs)}" - flat_r.append(preds) + flat_r.append(aux_tasks[obj](mols, is_valid)) flat_rewards = torch.stack(flat_r, dim=1) - return FlatRewards(flat_rewards), is_valid + assert flat_rewards.shape[0] == len(mols) + return FlatRewards(flat_rewards), is_valid_t class QM9MOOTrainer(QM9GapTrainer): @@ -349,9 +331,3 @@ def _save_state(self, it): if self.task.focus_cond is not None and self.task.focus_cond.focus_model is not None: self.task.focus_cond.focus_model.save(pathlib.Path(self.cfg.log_dir)) return super()._save_state(it) - - def run(self): - super().run() - for hook in self.sampling_hooks: - if hasattr(hook, "terminate"): - hook.terminate() diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 9a5bc90c..366a03c8 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -356,14 +356,6 @@ def _save_state(self, it): 
self.task.focus_cond.focus_model.save(pathlib.Path(self.cfg.log_dir)) return super()._save_state(it) - def run(self): - super().run() - for hook in self.sampling_hooks: - if hasattr(hook, "terminate") and hook.terminate not in self.to_terminate: - hook.terminate() - - ;for terminate in self.to_terminate: - terminate() class RepeatedCondInfoDataset: def __init__(self, cond_info_vectors, repeat): diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index c77729c6..6848bf2b 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -398,12 +398,20 @@ def run(self, logger=None): logger.info("Final generation steps completed - " + " ".join(f"{k}:{v:.2f}" for k, v in final_info.items())) self.log(final_info, num_training_steps, "final") - # for pypy and other GC havers + # for pypy and other GC having implementations, we need to manually clean up del train_dl del valid_dl if self.cfg.num_final_gen_steps: del final_dl + def terminate(self): + for hook in self.sampling_hooks: + if hasattr(hook, "terminate") and hook.terminate not in self.to_terminate: + hook.terminate() + + for terminate in self.to_terminate: + terminate() + def _save_state(self, it): state = { "models_state_dict": [self.model.state_dict()], @@ -427,16 +435,8 @@ def log(self, info, index, key): for k, v in info.items(): self._summary_writer.add_scalar(f"{key}_{k}", v, index) - def close(self): - while len(self.to_terminate) > 0: - try: - i = self.to_terminate.pop() - i.close() - except Exception as e: - print(e) - def __del__(self): - self.close() + self.terminate() def cycle(it): diff --git a/src/gflownet/utils/multiprocessing_proxy.py b/src/gflownet/utils/multiprocessing_proxy.py index 6d26f3cf..df13b565 100644 --- a/src/gflownet/utils/multiprocessing_proxy.py +++ b/src/gflownet/utils/multiprocessing_proxy.py @@ -62,7 +62,6 @@ def __len__(self): return self.out_queue.get() - class MPObjectProxy: """This class maintains a reference to some object and creates a `placeholder` attribute which can be safely passed to @@ -155,9 +154,11 @@ def run(self): else: msg = self.to_cpu(result) self.out_queues[qi].put(self.encode(msg)) + def terminate(self): self.stop.set() + def mp_object_wrapper(obj, num_workers, cast_types, pickle_messages: bool = False): """Construct a multiprocessing object proxy for torch DataLoaders so that it does not need to be copied in every worker's memory. 
For example, From 71d62d5f91441c09909bfc6db29fea470f911d1b Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Fri, 16 Feb 2024 15:29:37 -0500 Subject: [PATCH 26/33] ft --- src/gflownet/tasks/qm9_moo.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py index a963f75f..7408269d 100644 --- a/src/gflownet/tasks/qm9_moo.py +++ b/src/gflownet/tasks/qm9_moo.py @@ -5,7 +5,6 @@ import torch import torch.nn as nn import torch_geometric.data as gd -from rdkit.Chem import QED, Descriptors from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor from torch.utils.data import Dataset @@ -17,9 +16,9 @@ from gflownet.data.qm9 import QM9Dataset from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.tasks.qm9.qm9 import QM9GapTask, QM9GapTrainer -from gflownet.tasks.seh_frag_moo import RepeatedCondInfoDataset, aux_tasks, safe +from gflownet.tasks.seh_frag_moo import RepeatedCondInfoDataset, aux_tasks from gflownet.trainer import FlatRewards, RewardScalar -from gflownet.utils import metrics, sascore +from gflownet.utils import metrics from gflownet.utils.conditioning import FocusRegionConditional, MultiObjectiveWeightedPreferences from gflownet.utils.multiobjective_hooks import MultiObjectiveStatsHook, TopKHook from gflownet.utils.transforms import to_logreward From 4f058e58279593ae5c79e871b88d91122a0b0043 Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Fri, 16 Feb 2024 15:29:52 -0500 Subject: [PATCH 27/33] f --- src/gflownet/tasks/qm9/qm9.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gflownet/tasks/qm9/qm9.py b/src/gflownet/tasks/qm9/qm9.py index 0cf6d7b2..71900e97 100644 --- a/src/gflownet/tasks/qm9/qm9.py +++ b/src/gflownet/tasks/qm9/qm9.py @@ -1,4 +1,3 @@ -import os from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np From f947bb34c4943a4ef5416cb0b736bd82c6d99745 Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Fri, 16 Feb 2024 16:42:29 -0500 Subject: [PATCH 28/33] fix typo --- src/gflownet/online_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index 81fd7ae7..c815cce9 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -66,7 +66,7 @@ def _opt(self, params, lr=None, momentum=None): eps=self.cfg.opt.adam_eps, ) - raise NotImplementedError(f"{self.opt.opt} is not implemented") + raise NotImplementedError(f"{self.cfg.opt.opt} is not implemented") def setup(self): super().setup() From d36952b044171cffbb314b4e79e7fd6ad270eec7 Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Fri, 16 Feb 2024 21:16:50 -0500 Subject: [PATCH 29/33] fix runtime errors --- pyproject.toml | 1 + src/gflownet/algo/graph_sampling.py | 1 + src/gflownet/algo/trajectory_balance.py | 10 ++++++++++ src/gflownet/tasks/qm9/qm9.py | 16 +++++----------- src/gflownet/tasks/{ => qm9}/qm9_moo.py | 5 +++-- src/gflownet/tasks/seh_frag.py | 17 ++++------------- src/gflownet/tasks/seh_frag_moo.py | 8 ++++---- src/gflownet/trainer.py | 6 +++--- 8 files changed, 31 insertions(+), 33 deletions(-) rename src/gflownet/tasks/{ => qm9}/qm9_moo.py (99%) diff --git a/pyproject.toml b/pyproject.toml index d588c636..58a83288 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,7 @@ dependencies = [ "pyro-ppl", "gpytorch", "omegaconf>=2.3", + "pandas", # needed for QM9 and HDF5 support. 
] [project.optional-dependencies] diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 5a6ace57..7ad4fc0a 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -246,6 +246,7 @@ def sample_backward_from_graphs( def not_done(lst): return [e for i, e in enumerate(lst) if not done[i]] + # TODO: This should be doable. if random_action_prob > 0: raise NotImplementedError("Random action not implemented for backward sampling") diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 3a98423f..634dd752 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -1,3 +1,4 @@ +from copy import deepcopy from typing import Optional, Tuple import networkx as nx @@ -214,6 +215,15 @@ def create_training_data_from_graphs( ] + [1] traj["bck_logprobs"] = (1 / torch.tensor(n_back).float()).log().to(self.ctx.device) traj["result"] = traj["traj"][-1][0] + if self.cfg.do_parameterize_p_b: + traj["bck_a"] = [GraphAction(GraphActionType.Stop)] + [self.env.reverse(g, a) for g, a in traj["traj"]] + # There needs to be an additonal node when we're parameterizing P_B, + # See sampling with parametrized P_B + traj["traj"].append(deepcopy(traj["traj"][-1])) + traj["is_sink"] = [0 for _ in traj["traj"]] + traj["is_sink"][-1] = 1 + traj["is_sink"][-2] = 1 + assert len(traj["bck_a"]) == len(traj["traj"]) == len(traj["is_sink"]) return trajs def get_idempotent_actions(self, g: Graph, gd: gd.Data, gp: Graph, action: GraphAction, return_aidx: bool = True): diff --git a/src/gflownet/tasks/qm9/qm9.py b/src/gflownet/tasks/qm9/qm9.py index 71900e97..2b0944f3 100644 --- a/src/gflownet/tasks/qm9/qm9.py +++ b/src/gflownet/tasks/qm9/qm9.py @@ -83,19 +83,13 @@ def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Ten def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: return RewardScalar(self.temperature_conditional.transform(cond_info, to_logreward(flat_reward))) - def compute_reward_from_graph(self, graphs: List[gd.Data], is_valid: Optional[Tensor]) -> Tensor: + def compute_reward_from_graph(self, graphs: List[gd.Data]) -> Tensor: batch = gd.Batch.from_data_list([i for i in graphs if i is not None]) batch.to(self.device) preds = self.models["mxmnet_gap"](batch).reshape((-1,)).data.cpu() / mxmnet.HAR2EV # type: ignore[attr-defined] preds[preds.isnan()] = 1 - preds = self.flat_reward_transform(preds).clip(1e-4, 2).reshape((-1, 1)) - if is_valid is not None: - assert len(is_valid) >= len(preds) - preds_full = torch.zeros(len(is_valid), 1) - preds_full[is_valid] = preds - return preds_full - else: - return preds + preds = self.flat_reward_transform(preds).clip(1e-4, 2).reshape(-1,) + return preds def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: graphs = [mxmnet.mol2graph(i) for i in mols] # type: ignore[attr-defined] @@ -103,8 +97,8 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: if not is_valid.any(): return FlatRewards(torch.zeros((0, 1))), is_valid - preds = self.compute_reward_from_graph(graphs, is_valid).reshape((-1, 1)) - assert len(preds) == len(mols) + preds = self.compute_reward_from_graph(graphs).reshape((-1, 1)) + assert len(preds) == is_valid.sum() return FlatRewards(preds), is_valid diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9/qm9_moo.py similarity index 99% rename from 
src/gflownet/tasks/qm9_moo.py rename to src/gflownet/tasks/qm9/qm9_moo.py index 7408269d..ffc03364 100644 --- a/src/gflownet/tasks/qm9_moo.py +++ b/src/gflownet/tasks/qm9/qm9_moo.py @@ -172,18 +172,19 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: assert len(graphs) == len(mols) is_valid = [i is not None for i in graphs] is_valid_t = torch.tensor(is_valid, dtype=torch.bool) + if not any(is_valid): return FlatRewards(torch.zeros((0, len(self.objectives)))), is_valid_t else: flat_r: List[Tensor] = [] for obj in self.objectives: if obj == "gap": - flat_r.append(super().compute_reward_from_graph(graphs, is_valid_t)) + flat_r.append(super().compute_reward_from_graph(graphs)) else: flat_r.append(aux_tasks[obj](mols, is_valid)) flat_rewards = torch.stack(flat_r, dim=1) - assert flat_rewards.shape[0] == len(mols) + assert flat_rewards.shape[0] == is_valid_t.sum() return FlatRewards(flat_rewards), is_valid_t diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index f13979e4..466618ef 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -63,21 +63,12 @@ def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Ten def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: return RewardScalar(self.temperature_conditional.transform(cond_info, to_logreward(flat_reward))) - def compute_reward_from_graph(self, graphs: List[Data], is_valid: Optional[Tensor]) -> Tensor: + def compute_reward_from_graph(self, graphs: List[Data]) -> Tensor: batch = gd.Batch.from_data_list([i for i in graphs if i is not None]) - if is_valid is None: - is_valid = torch.tensor([i is not None for i in graphs], dtype=torch.bool) batch.to(self.device) preds = self.models["seh"](batch).reshape((-1,)).data.cpu() preds[preds.isnan()] = 0 - preds = self.flat_reward_transform(preds).clip(1e-4, 100).reshape((-1,)) - if is_valid is not None: - assert len(is_valid) >= len(preds) - preds_full = torch.zeros(len(is_valid), 1) - preds_full[is_valid] = preds - return preds_full - else: - return preds + return self.flat_reward_transform(preds).clip(1e-4, 100).reshape((-1,)) def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: graphs = [bengio2021flow.mol2graph(i) for i in mols] @@ -85,8 +76,8 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: if not is_valid.any(): return FlatRewards(torch.zeros((0, 1))), is_valid - preds = self.compute_reward_from_graph(graphs, is_valid).reshape((-1, 1)) - assert len(preds) == len(mols) + preds = self.compute_reward_from_graph(graphs).reshape((-1, 1)) + assert len(preds) == is_valid.sum() return FlatRewards(preds), is_valid diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 366a03c8..23adc295 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -33,21 +33,21 @@ def safe(f, x, default): def mol2mw(mols: list[RDMol], is_valid: list[bool], default=1000): - molwts = torch.tensor([safe(Descriptors.MolWt, i, default) if v else default for i, v in zip(mols, is_valid)]) + molwts = torch.tensor([safe(Descriptors.MolWt, i, default) for i, v in zip(mols, is_valid) if v]) molwts = ((300 - molwts) / 700 + 1).clip(0, 1) # 1 until 300 then linear decay to 0 until 1000 return molwts def mol2sas(mols: list[RDMol], is_valid: list[bool], default=10): sas = torch.tensor( - [safe(sascore.calculateScore, i, default) if is_valid else 
default for i, v in zip(mols, is_valid)] + [safe(sascore.calculateScore, i, default) for i, v in zip(mols, is_valid) if v] ) sas = (10 - sas) / 9 # Turn into a [0-1] reward return sas def mol2qed(mols: list[RDMol], is_valid: list[bool], default=0): - return torch.tensor([safe(QED.qed, i, 0) if v else default for i, v in zip(mols, is_valid)]) + return torch.tensor([safe(QED.qed, i, 0) for i, v in zip(mols, is_valid) if v]) aux_tasks = {"qed": mol2qed, "sa": mol2sas, "mw": mol2mw} @@ -207,7 +207,7 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: flat_r: List[Tensor] = [] for obj in self.objectives: if obj == "seh": - flat_r.append(super().compute_reward_from_graph(graphs, is_valid_t)) + flat_r.append(super().compute_reward_from_graph(graphs)) else: flat_r.append(aux_tasks[obj](mols, is_valid)) diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 6848bf2b..e60d742e 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -188,14 +188,14 @@ def _wrap_for_mp(self, obj, send_to_device=False): if send_to_device: obj.to(self.device) if self.cfg.num_workers > 0 and obj is not None: - placeholder = mp_object_wrapper( + wapper = mp_object_wrapper( obj, self.cfg.num_workers, cast_types=(gd.Batch, GraphActionCategorical, SeqBatch), pickle_messages=self.cfg.pickle_mp_messages, ) - self.to_terminate.append(placeholder.terminate) - return placeholder, torch.device("cpu") + self.to_terminate.append(wapper.terminate) + return wapper.placeholder, torch.device("cpu") else: return obj, self.device From ae7d74029e59665747ff7fbc9658c45d8db02db6 Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Fri, 16 Feb 2024 21:24:15 -0500 Subject: [PATCH 30/33] fmt --- src/gflownet/algo/trajectory_balance.py | 4 ++-- src/gflownet/tasks/qm9/qm9.py | 10 ++++++++-- src/gflownet/tasks/seh_frag.py | 2 +- src/gflownet/tasks/seh_frag_moo.py | 4 +--- 4 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 634dd752..75e5471f 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import networkx as nx import numpy as np @@ -207,7 +207,7 @@ def create_training_data_from_graphs( return self.graph_sampler.sample_backward_from_graphs( graphs, model if self.cfg.do_parameterize_p_b else None, cond_info, dev, random_action_prob ) - trajs = [{"traj": generate_forward_trajectory(i)} for i in graphs] + trajs: List[Dict[str, Any]] = [{"traj": generate_forward_trajectory(i)} for i in graphs] for traj in trajs: n_back = [ self.env.count_backward_transitions(gp, check_idempotent=self.cfg.do_correct_idempotent) diff --git a/src/gflownet/tasks/qm9/qm9.py b/src/gflownet/tasks/qm9/qm9.py index 2b0944f3..488df326 100644 --- a/src/gflownet/tasks/qm9/qm9.py +++ b/src/gflownet/tasks/qm9/qm9.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Tuple, Union import numpy as np import torch @@ -88,7 +88,13 @@ def compute_reward_from_graph(self, graphs: List[gd.Data]) -> Tensor: batch.to(self.device) preds = self.models["mxmnet_gap"](batch).reshape((-1,)).data.cpu() / mxmnet.HAR2EV # type: ignore[attr-defined] preds[preds.isnan()] = 1 - preds = self.flat_reward_transform(preds).clip(1e-4, 2).reshape(-1,) + preds = ( + self.flat_reward_transform(preds) + 
.clip(1e-4, 2) + .reshape( + -1, + ) + ) return preds def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 466618ef..91d65818 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -1,7 +1,7 @@ import os import shutil import socket -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Tuple, Union import numpy as np import torch diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 23adc295..31d8f769 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -39,9 +39,7 @@ def mol2mw(mols: list[RDMol], is_valid: list[bool], default=1000): def mol2sas(mols: list[RDMol], is_valid: list[bool], default=10): - sas = torch.tensor( - [safe(sascore.calculateScore, i, default) for i, v in zip(mols, is_valid) if v] - ) + sas = torch.tensor([safe(sascore.calculateScore, i, default) for i, v in zip(mols, is_valid) if v]) sas = (10 - sas) / 9 # Turn into a [0-1] reward return sas From 922dc8b3a6c417f5457e183e3e67727232f1825d Mon Sep 17 00:00:00 2001 From: Sobhan Mohammadpour Date: Sat, 17 Feb 2024 00:39:16 -0500 Subject: [PATCH 31/33] close hdf5 gracefully --- src/gflownet/data/qm9.py | 10 +++++++++- src/gflownet/tasks/qm9/qm9.py | 2 ++ src/gflownet/tasks/qm9/qm9_moo.py | 2 ++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/gflownet/data/qm9.py b/src/gflownet/data/qm9.py index 46c44ad2..f35bdb14 100644 --- a/src/gflownet/data/qm9.py +++ b/src/gflownet/data/qm9.py @@ -14,9 +14,13 @@ class QM9Dataset(Dataset): def __init__(self, h5_file=None, xyz_file=None, train=True, targets=["gap"], split_seed=142857, ratio=0.9): if h5_file is not None: - self.df = pd.HDFStore(h5_file, "r")["df"] + + self.hdf = pd.HDFStore(h5_file, "r") + self.df = self.hdf["df"] + self.is_hdf = True elif xyz_file is not None: self.df = load_tar(xyz_file) + self.is_hdf = False else: raise ValueError("Either h5_file or xyz_file must be provided") rng = np.random.default_rng(split_seed) @@ -47,6 +51,10 @@ def __getitem__(self, idx): torch.tensor([self.df[t][self.idcs[idx]] for t in self.targets]).float(), ) + def terminate(self): + if self.is_hdf: + self.hdf.close() + def load_tar(xyz_file): labels = ["rA", "rB", "rC", "mu", "alpha", "homo", "lumo", "gap", "r2", "zpve", "U0", "U", "H", "G", "Cv"] diff --git a/src/gflownet/tasks/qm9/qm9.py b/src/gflownet/tasks/qm9/qm9.py index 488df326..09410579 100644 --- a/src/gflownet/tasks/qm9/qm9.py +++ b/src/gflownet/tasks/qm9/qm9.py @@ -147,6 +147,8 @@ def setup_env_context(self): def setup_data(self): self.training_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=True, targets=["gap"]) self.test_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=False, targets=["gap"]) + self.to_terminate.append(self.training_data.terminate) + self.to_terminate.append(self.test_data.terminate ) def setup_task(self): self.task = QM9GapTask( diff --git a/src/gflownet/tasks/qm9/qm9_moo.py b/src/gflownet/tasks/qm9/qm9_moo.py index ffc03364..cb0e8277 100644 --- a/src/gflownet/tasks/qm9/qm9_moo.py +++ b/src/gflownet/tasks/qm9/qm9_moo.py @@ -307,6 +307,8 @@ def setup(self): def setup_data(self): self.training_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=True, targets=self.cfg.task.qm9_moo.objectives) self.test_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=False, targets=self.cfg.task.qm9_moo.objectives) + 
self.to_terminate.append(self.training_data.terminate)
+        self.to_terminate.append(self.test_data.terminate)

     def build_callbacks(self):
         # We use this class-based setup to be compatible with the DeterminedAI API, but no direct

From b5ddddf6b1b889feca9ae09262dd408df7ff1b50 Mon Sep 17 00:00:00 2001
From: Sobhan Mohammadpour
Date: Sat, 17 Feb 2024 00:40:12 -0500
Subject: [PATCH 32/33] fmt

---
 src/gflownet/tasks/qm9/qm9.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/gflownet/tasks/qm9/qm9.py b/src/gflownet/tasks/qm9/qm9.py
index 09410579..d66f571a 100644
--- a/src/gflownet/tasks/qm9/qm9.py
+++ b/src/gflownet/tasks/qm9/qm9.py
@@ -148,7 +148,7 @@ def setup_data(self):
         self.training_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=True, targets=["gap"])
         self.test_data = QM9Dataset(self.cfg.task.qm9.h5_path, train=False, targets=["gap"])
         self.to_terminate.append(self.training_data.terminate)
-        self.to_terminate.append(self.test_data.terminate )
+        self.to_terminate.append(self.test_data.terminate)

     def setup_task(self):
         self.task = QM9GapTask(

From a715adb797177a76a9988e6fc56ef699fe4c5fab Mon Sep 17 00:00:00 2001
From: Sobhan Mohammadpour
Date: Sat, 17 Feb 2024 20:06:51 -0500
Subject: [PATCH 33/33] revert default num_graph_out

---
 src/gflownet/models/graph_transformer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py
index 79d42cbf..8c3993f0 100644
--- a/src/gflownet/models/graph_transformer.py
+++ b/src/gflownet/models/graph_transformer.py
@@ -169,7 +169,7 @@ def __init__(
         self,
         env_ctx,
         cfg: Config,
-        num_graph_out=2,
+        num_graph_out=1,
         do_bck=False,
     ):
         """See `GraphTransformer` for argument values"""
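
The patches above converge on a single cleanup convention: every owner of a background resource (the multiprocessing proxy, the multi-objective stats hook, the HDF5-backed dataset) exposes a terminate() method, registers it in the trainer's to_terminate list, and the trainer drains that list on shutdown. The following is a rough sketch of that convention only; ToyProxy and ToyTrainer are hypothetical stand-ins written for illustration, not classes from this series.

# Illustrative sketch, not part of the patches. It mirrors the to_terminate /
# terminate() convention introduced in PATCH 24/33; ToyProxy and ToyTrainer are
# hypothetical stand-ins for resources like MPObjectProxy or QM9Dataset.
import threading
from typing import Callable, List


class ToyProxy:
    """Stand-in for an object that owns a background thread."""

    def __init__(self):
        self.stop = threading.Event()

    def terminate(self):
        # Signal the (hypothetical) worker thread to exit.
        self.stop.set()


class ToyTrainer:
    def __init__(self):
        # Every owner of a resource registers a zero-argument callable here.
        self.to_terminate: List[Callable[[], None]] = []

    def setup(self):
        proxy = ToyProxy()
        self.to_terminate.append(proxy.terminate)

    def terminate(self):
        # Drain the list once; calling it again is a no-op.
        for terminate in self.to_terminate:
            terminate()
        self.to_terminate.clear()

    def __del__(self):
        self.terminate()


if __name__ == "__main__":
    trainer = ToyTrainer()
    trainer.setup()
    trainer.terminate()  # releases the background resources deterministically

Keeping the cleanup callables in one list is what lets explicit shutdown and garbage collection share the same path, instead of tracking proxies, hooks, and datasets separately.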