diff --git a/README.md b/README.md index 1745d09e..ec90e69e 100644 --- a/README.md +++ b/README.md @@ -11,34 +11,14 @@ GFlowNet-related training and environment code on graphs. **Primer** -[GFlowNet](https://yoshuabengio.org/2022/03/05/generative-flow-networks/), short for Generative Flow Network, is a novel generative modeling framework, particularly suited for discrete, combinatorial objects. Here in particular it is implemented for graph generation. +GFlowNet [[1]](https://yoshuabengio.org/2022/03/05/generative-flow-networks/), [[2]](https://www.gflownet.org/), [[3]](https://github.com/zdhNarsil/Awesome-GFlowNets), short for Generative Flow Network, is a novel generative modeling framework, particularly suited for discrete, combinatorial objects. Here in particular it is implemented for graph generation. -The idea behind GFN is to estimate flows in a (graph-theoretic) directed acyclic network*. The network represents all possible ways of constructing an object, and so knowing the flow gives us a policy which we can follow to sequentially construct objects. Such a sequence of partially constructed objects is a _trajectory_. *Perhaps confusingly, the _network_ in GFN refers to the state space, not a neural network architecture. +The idea behind GFN is to estimate flows in a (graph-theoretic) directed acyclic network*. The network represents all possible ways of constructing objects, and so knowing the flow gives us a policy which we can follow to sequentially construct objects. Such a sequence of partially constructed objects is a _trajectory_. *Perhaps confusingly, the _network_ in GFN refers to the state space, not a neural network architecture. -Here the objects we construct are themselves graphs (e.g. graphs of atoms), which are constructed node by node. To make policy predictions, we use a graph neural network. This GNN outputs per-node logits (e.g. add an atom to this atom, or add a bond between these two atoms), as well as per-graph logits (e.g. stop/"done constructing this object"). +The main focus of this library (although it can do other things) is to construct graphs (e.g. graphs of atoms), which are constructed node by node. To make policy predictions, we use a graph neural network. This GNN outputs per-node logits (e.g. add an atom to this atom, or add a bond between these two atoms), as well as per-graph logits (e.g. stop/"done constructing this object"). -The GNN model can be trained on a mix of existing data (offline) and self-generated data (online), the latter being obtained by querying the model sequentially to obtain trajectories. For offline data, we can easily generate trajectories since we know the end state. +This library supports a variety of GFN algorithms (as well as some baselines), and supports training on a mix of existing data (offline) and self-generated data (online), the latter being obtained by querying the model sequentially to obtain trajectories. -## Repo overview - -- [algo](src/gflownet/algo), contains GFlowNet algorithms implementations ([Trajectory Balance](https://arxiv.org/abs/2201.13259), [SubTB](https://arxiv.org/abs/2209.12782), [Flow Matching](https://arxiv.org/abs/2106.04399)), as well as some baselines. These implement how to sample trajectories from a model and compute the loss from trajectories. -- [data](src/gflownet/data), contains dataset definitions, data loading and data sampling utilities. -- [envs](src/gflownet/envs), contains environment classes; a graph-building environment base, and a molecular graph context class. 
The base environment is agnostic to what kind of graph is being made, and the context class specifies mappings from graphs to objects (e.g. molecules) and torch geometric Data. -- [examples](docs/examples), contains simple example implementations of GFlowNet. -- [models](src/gflownet/models), contains model definitions. -- [tasks](src/gflownet/tasks), contains training code. - - [qm9](src/gflownet/tasks/qm9/qm9.py), temperature-conditional molecule sampler based on QM9's HOMO-LUMO gap data as a reward. - - [seh_frag](src/gflownet/tasks/seh_frag.py), reproducing Bengio et al. 2021, fragment-based molecule design targeting the sEH protein - - [seh_frag_moo](src/gflownet/tasks/seh_frag_moo.py), same as the above, but with multi-objective optimization (incl. QED, SA, and molecule weight objectives). -- [utils](src/gflownet/utils), contains utilities (multiprocessing, metrics, conditioning). -- [`trainer.py`](src/gflownet/trainer.py), defines a general harness for training GFlowNet models. -- [`online_trainer.py`](src/gflownet/online_trainer.py), defines a typical online-GFN training loop. - -See [implementation notes](docs/implementation_notes.md) for more. - -## Getting started - -A good place to get started is with the [sEH fragment-based MOO task](src/gflownet/tasks/seh_frag_moo.py). The file `seh_frag_moo.py` is runnable as-is (although you may want to change the default configuration in `main()`). ## Installation @@ -62,6 +42,30 @@ pip install git+https://github.com/recursionpharma/gflownet.git@v0.0.10 --find-l If package dependencies seem not to work, you may need to install the exact frozen versions listed `requirements/`, i.e. `pip install -r requirements/main-3.10.txt`. +## Getting started + +A good place to get started immediately is with the [sEH fragment-based MOO task](src/gflownet/tasks/seh_frag_moo.py). The file `seh_frag_moo.py` is runnable as-is (although you may want to change the default configuration in `main()`). + +For a gentler introduction to the library, see [Getting Started](docs/getting_started.md). For a more in-depth look at the library, see [Implementation Notes](docs/implementation_notes.md). + +## Repo overview + +- [algo](src/gflownet/algo), contains GFlowNet algorithms implementations ([Trajectory Balance](https://arxiv.org/abs/2201.13259), [SubTB](https://arxiv.org/abs/2209.12782), [Flow Matching](https://arxiv.org/abs/2106.04399)), as well as some baselines. These implement how to sample trajectories from a model and compute the loss from trajectories. +- [data](src/gflownet/data), contains dataset definitions, data loading and data sampling utilities. +- [envs](src/gflownet/envs), contains environment classes; the base environment is agnostic to what kind of graph is being made, and context classes specify mappings from graphs to objects (e.g. molecules) and torch geometric Data. +- [examples](docs/examples), contains simple example implementations of GFlowNet. +- [models](src/gflownet/models), contains model definitions. +- [tasks](src/gflownet/tasks), contains training code. + - [qm9](src/gflownet/tasks/qm9/qm9.py), temperature-conditional molecule sampler based on QM9's HOMO-LUMO gap data as a reward. + - [seh_frag](src/gflownet/tasks/seh_frag.py), reproducing Bengio et al. 2021, fragment-based molecule design targeting the sEH protein + - [seh_frag_moo](src/gflownet/tasks/seh_frag_moo.py), same as the above, but with multi-objective optimization (incl. QED, SA, and molecule weight objectives). 
+- [utils](src/gflownet/utils), contains utilities (multiprocessing, metrics, conditioning). +- [`trainer.py`](src/gflownet/trainer.py), defines a general harness for training GFlowNet models. +- [`online_trainer.py`](src/gflownet/online_trainer.py), defines a typical online-GFN training loop. + +See [implementation notes](docs/implementation_notes.md) for more. + + ## Developing & Contributing External contributions are welcome. @@ -73,3 +77,5 @@ pip install -e '.[dev]' --find-links https://data.pyg.org/whl/torch-2.1.2+cu121. We use `tox` to run tests and linting, and `pre-commit` to run checks before committing. To ensure that these checks pass, simply run `tox -e style` and `tox run` to run linters and tests, respectively. + +For more information, see [Contributing](docs/contributing.md). diff --git a/docs/contributing.md b/docs/contributing.md new file mode 100644 index 00000000..b63b082f --- /dev/null +++ b/docs/contributing.md @@ -0,0 +1,38 @@ +# Contributing + +Contributions to the repository are welcome, and we encourage you to open issues and pull requests. In general, it is recommended to fork this repository and open a pull request from your fork to the `trunk` branch. PRs are encouraged to be short and focused, and to include tests and documentation where appropriate. + +## Installation + +To install the developers dependencies run: +``` +pip install -e '.[dev]' --find-links https://data.pyg.org/whl/torch-2.1.2+cu121.html +``` + +## Dependencies + +Dependencies are defined in `pyproject.toml`, and frozen versions that are known to work are provided in `requirements/`. + +To regenerate the frozen versions, run `./generate_requirements.sh `. See comments within. + +## Linting and testing + +We use `tox` to run tests and linting, and `pre-commit` to run checks before committing. +To ensure that these checks pass, simply run `tox -e style` and `tox run` to run linters and tests, respectively. + +`tox` itself runs many linters, but the most important ones are `black`, `ruff`, `isort`, and `mypy`. The full list +of linting tools is found in `.pre-commit-config.yaml`, while `tox.ini` defines the environments under which these +linters (as well as tests) are run. + +## Github Actions + +We use Github Actions to run tests and linting on every push and pull request. The configuration for these actions is found in `.github/workflows/`. + +The cascade of events is as follows: +- For `build-and-test`, `tox -> testenv:py310 -> pytest` is run. +- For `code-quality`, `tox -e style -> testenv:style -> pre-commit -> {isort, black, mypy, bandit, ruff, & others}`. This and the "others" are defined in `.pre-commit-config.yaml` and include things like checking for secrets and trailing whitespace. + +## Style Guide + +On top of `black`-as-a-style-guide, we generally adhere to the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html). +Our docstrings follow the [numpydoc](https://numpydoc.readthedocs.io/en/latest/format.html) format, and we use type hints throughout the codebase. diff --git a/docs/getting_started.md b/docs/getting_started.md new file mode 100644 index 00000000..8ec37f3e --- /dev/null +++ b/docs/getting_started.md @@ -0,0 +1,11 @@ +# Getting Started + +For an introduction to the library, see [this colab notebook](https://colab.research.google.com/drive/1wANyo6Y-ceYEto9-p50riCsGRb_6U6eH). + +For an introduction to using `wandb` to log experiments, see [this demo](../src/gflownet/hyperopt/wandb_demo). 
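+
+If you prefer to stay in Python rather than editing `main()` directly, the task files all follow the same pattern, sketched below. This is a minimal sketch only: the trainer class name and the `Config` fields shown (e.g. `num_training_steps`, `overwrite_existing_exp`) are taken from the task files and `src/gflownet/config.py`, so verify them against your checkout before relying on them.
+
+```python
+import torch
+
+from gflownet.config import Config, init_empty
+from gflownet.tasks.seh_frag import SEHFragTrainer  # assumed class name; see seh_frag.py
+
+# Start from an empty config and set only what we need; everything else
+# falls back to the defaults defined in gflownet.config.
+config = init_empty(Config())
+config.log_dir = "./logs/debug_seh_frag"
+config.overwrite_existing_exp = True
+config.device = "cuda" if torch.cuda.is_available() else "cpu"
+config.print_every = 1
+config.num_training_steps = 100  # tiny smoke-test run
+config.num_workers = 0
+
+trainer = SEHFragTrainer(config)
+trainer.run()
+```
+
+The same pattern applies to the other trainers (e.g. the multi-objective trainer in `seh_frag_moo.py`); only the task-specific config sections change.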
+ +For more general introductions to GFlowNets, check out the following: +- The 2023 [GFlowNet workshop](https://gflownet.org/) has several introductory talks and colab tutorials. +- This high-level [GFlowNet colab tutorial](https://colab.research.google.com/drive/1fUMwgu2OhYpQagpzU5mhe9_Esib3Q2VR) (updated versions of which were written for the 2023 workshop, in particular for continuous GFNs). + +A good place to get started immediately is with the [sEH fragment-based MOO task](src/gflownet/tasks/seh_frag_moo.py). The file `seh_frag_moo.py` is runnable as-is (although you may want to change the default configuration in `main()`). \ No newline at end of file diff --git a/src/gflownet/__init__.py b/src/gflownet/__init__.py index 06bb6ba5..6cb8f979 100644 --- a/src/gflownet/__init__.py +++ b/src/gflownet/__init__.py @@ -1,17 +1,18 @@ -from typing import Dict, List, NewType, Optional, Tuple +from typing import Any, Dict, List, NewType, Optional, Tuple import torch_geometric.data as gd -from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor, nn from .config import Config -# This type represents an unprocessed list of reward signals/conditioning information -FlatRewards = NewType("FlatRewards", Tensor) # type: ignore +# This type represents a set of scalar properties attached to each object in a batch. +ObjectProperties = NewType("ObjectProperties", Tensor) # type: ignore -# This type represents the outcome for a multi-objective task of -# converting FlatRewards to a scalar, e.g. (sum R_i omega_i) ** beta -RewardScalar = NewType("RewardScalar", Tensor) # type: ignore +# This type represents log-scalars, in particular log-rewards at the scale we operate with with GFlowNets +# for example, converting a reward ObjectProperties to a log-scalar with log [(sum R_i omega_i) ** beta] +LogScalar = NewType("LogScalar", Tensor) # type: ignore +# This type represents linear-scalars +LinScalar = NewType("LinScalar", Tensor) # type: ignore class GFNAlgorithm: @@ -75,15 +76,15 @@ def get_random_action_prob(self, it: int): class GFNTask: - def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: + def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], obj_props: ObjectProperties) -> LogScalar: """Combines a minibatch of reward signal vectors and conditional information into a scalar reward. Parameters ---------- cond_info: Dict[str, Tensor] A dictionary with various conditional informations (e.g. temperature) - flat_reward: FlatRewards - A 2d tensor where each row represents a series of flat rewards. + obj_props: ObjectProperties + A 2d tensor where each row represents a series of object properties. Returns ------- @@ -92,18 +93,35 @@ def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: Flat """ raise NotImplementedError() - def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: - """Compute the flat rewards of mols according the the tasks' proxies + def compute_obj_properties(self, objs: List[Any]) -> Tuple[ObjectProperties, Tensor]: + """Compute the flat rewards of objs according the the tasks' proxies Parameters ---------- - mols: List[RDMol] - A list of RDKit molecules. + objs: List[Any] + A list of n objects. Returns ------- - reward: FlatRewards - A 2d tensor, a vector of scalar reward for valid each molecule. + obj_probs: ObjectProperties + A 2d tensor (m, p), a vector of scalar properties for the m <= n valid objects. 
is_valid: Tensor - A 1d tensor, a boolean indicating whether the molecule is valid. + A 1d tensor (n,), a boolean indicating whether each object is valid. + """ + raise NotImplementedError() + + def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: + """Sample conditional information for n objects + + Parameters + ---------- + n: int + The number of objects to sample conditional information for. + train_it: int + The training iteration number. + + Returns + ------- + cond_info: Dict[str, Tensor] + A dictionary with various conditional informations (e.g. temperature) """ raise NotImplementedError() diff --git a/src/gflownet/algo/advantage_actor_critic.py b/src/gflownet/algo/advantage_actor_critic.py index c657e88e..7077a9d1 100644 --- a/src/gflownet/algo/advantage_actor_critic.py +++ b/src/gflownet/algo/advantage_actor_critic.py @@ -1,4 +1,3 @@ -import numpy as np import torch import torch.nn as nn import torch_geometric.data as gd @@ -16,7 +15,6 @@ def __init__( self, env: GraphBuildingEnv, ctx: GraphBuildingEnvContext, - rng: np.random.RandomState, cfg: Config, ): """Advantage Actor-Critic implementation, see @@ -34,15 +32,12 @@ def __init__( A graph environment. ctx: GraphBuildingEnvContext A context. - rng: np.random.RandomState - rng used to take random actions cfg: Config The experiment configuration """ self.ctx = ctx self.env = env - self.rng = rng self.max_len = cfg.algo.max_len self.max_nodes = cfg.algo.max_nodes self.illegal_action_logreward = cfg.algo.illegal_action_logreward @@ -54,7 +49,7 @@ def __init__( # Experimental flags self.sample_temp = 1 self.do_q_prime_correction = False - self.graph_sampler = GraphSampler(ctx, env, self.max_len, self.max_nodes, rng, self.sample_temp) + self.graph_sampler = GraphSampler(ctx, env, self.max_len, self.max_nodes, self.sample_temp) def create_training_data_from_own_samples( self, model: nn.Module, n: int, cond_info: Tensor, random_action_prob: float @@ -82,7 +77,7 @@ def create_training_data_from_own_samples( """ dev = get_worker_device() cond_info = cond_info.to(dev) - data = self.graph_sampler.sample_from_model(model, n, cond_info, dev, random_action_prob) + data = self.graph_sampler.sample_from_model(model, n, cond_info, random_action_prob) return data def create_training_data_from_graphs(self, graphs): @@ -152,12 +147,12 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: # of length 4, trajectory 1 of length 3, and so on. 
batch_idx = torch.arange(num_trajs, device=dev).repeat_interleave(batch.traj_lens) - # Forward pass of the model, returns a GraphActionCategorical and per molecule predictions + # Forward pass of the model, returns a GraphActionCategorical and per graph predictions # Here we will interpret the logits of the fwd_cat as Q values policy, per_state_preds = model(batch, cond_info[batch_idx]) V = per_state_preds[:, 0] G = rewards[batch_idx] # The return is the terminal reward everywhere, we're using gamma==1 - G = G + (1 - batch.is_valid[batch_idx]) * self.invalid_penalty # Add in penalty for invalid mol + G = G + (1 - batch.is_valid[batch_idx]) * self.invalid_penalty # Add in penalty for invalid object A = G - V log_probs = policy.log_prob(batch.actions) diff --git a/src/gflownet/algo/envelope_q_learning.py b/src/gflownet/algo/envelope_q_learning.py index 8f836657..9bfc3345 100644 --- a/src/gflownet/algo/envelope_q_learning.py +++ b/src/gflownet/algo/envelope_q_learning.py @@ -1,4 +1,3 @@ -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -163,7 +162,6 @@ def __init__( env: GraphBuildingEnv, ctx: GraphBuildingEnvContext, task: GFNTask, - rng: np.random.RandomState, cfg: Config, ): """Envelope Q-Learning implementation, see @@ -181,15 +179,12 @@ def __init__( A graph environment. ctx: GraphBuildingEnvContext A context. - rng: np.random.RandomState - rng used to take random actions cfg: Config The experiment configuration """ self.ctx = ctx self.env = env self.task = task - self.rng = rng self.max_len = cfg.algo.max_len self.max_nodes = cfg.algo.max_nodes self.illegal_action_logreward = cfg.algo.illegal_action_logreward @@ -204,7 +199,7 @@ def __init__( # Experimental flags self.sample_temp = 1 self.do_q_prime_correction = False - self.graph_sampler = GraphSampler(ctx, env, self.max_len, self.max_nodes, rng, self.sample_temp) + self.graph_sampler = GraphSampler(ctx, env, self.max_len, self.max_nodes, self.sample_temp) def create_training_data_from_own_samples( self, model: nn.Module, n: int, cond_info: Tensor, random_action_prob: float @@ -232,7 +227,7 @@ def create_training_data_from_own_samples( """ dev = get_worker_device() cond_info = cond_info.to(dev) - data = self.graph_sampler.sample_from_model(model, n, cond_info, dev, random_action_prob) + data = self.graph_sampler.sample_from_model(model, n, cond_info, random_action_prob) return data def create_training_data_from_graphs(self, graphs): @@ -316,7 +311,7 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: # The position of the last graph of each trajectory final_graph_idx = torch.cumsum(batch.traj_lens, 0) - 1 - # Forward pass of the model, returns a GraphActionCategorical and per molecule predictions + # Forward pass of the model, returns a GraphActionCategorical and per graph predictions # Here we will interpret the logits of the fwd_cat as Q values # Q(s,a,omega) fwd_cat, per_state_preds = model(batch, cond_info[batch_idx], output_Qs=True) @@ -380,7 +375,7 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: shifted_Q_pareto = self.gamma * torch.cat([Q_pareto[1:], torch.zeros_like(Q_pareto[:1])], dim=0) # Replace Q(s_T) with R(tau). Since we've shifted the values in the array, Q(s_T) is Q(s_0) # of the next trajectory in the array, and rewards are terminal (0 except at s_T). 
- shifted_Q_pareto[final_graph_idx] = batch.flat_rewards + ((1 - batch.is_valid) * self.invalid_penalty)[:, None] + shifted_Q_pareto[final_graph_idx] = batch.obj_props + ((1 - batch.is_valid) * self.invalid_penalty)[:, None] y = shifted_Q_pareto.detach() # We now use the same log_prob trick to get Q(s,a,w) diff --git a/src/gflownet/algo/flow_matching.py b/src/gflownet/algo/flow_matching.py index 9c291356..33c436bf 100644 --- a/src/gflownet/algo/flow_matching.py +++ b/src/gflownet/algo/flow_matching.py @@ -1,5 +1,4 @@ import networkx as nx -import numpy as np import torch import torch.nn as nn import torch_geometric.data as gd @@ -39,10 +38,9 @@ def __init__( self, env: GraphBuildingEnv, ctx: GraphBuildingEnvContext, - rng: np.random.RandomState, cfg: Config, ): - super().__init__(env, ctx, rng, cfg) + super().__init__(env, ctx, cfg) self.fm_epsilon = torch.as_tensor(cfg.algo.fm.epsilon).log() # We include the "balanced loss" as a possibility to reproduce results from the FM paper, but # in a number of settings the regular loss is more stable. diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 2776a22e..9cf9faeb 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -14,6 +14,7 @@ action_type_to_mask, ) from gflownet.models.graph_transformer import GraphTransformerGFN +from gflownet.utils.misc import get_worker_device, get_worker_rng def relabel(g: Graph, ga: GraphAction): @@ -37,7 +38,7 @@ class GraphSampler: """A helper class to sample from GraphActionCategorical-producing models""" def __init__( - self, ctx, env, max_len, max_nodes, rng, sample_temp=1, correct_idempotent=False, pad_with_terminal_state=False + self, ctx, env, max_len, max_nodes, sample_temp=1, correct_idempotent=False, pad_with_terminal_state=False ): """ Parameters @@ -50,8 +51,6 @@ def __init__( If not None, ends trajectories of more than max_len steps. max_nodes: int If not None, ends trajectories of graphs with more than max_nodes steps (illegal action). 
- rng: np.random.RandomState - rng used to take random actions sample_temp: float [Experimental] Softmax temperature used when sampling correct_idempotent: bool @@ -63,16 +62,13 @@ def __init__( self.env = env self.max_len = max_len if max_len is not None else 128 self.max_nodes = max_nodes if max_nodes is not None else 128 - self.rng = rng # Experimental flags self.sample_temp = sample_temp self.sanitize_samples = True self.correct_idempotent = correct_idempotent self.pad_with_terminal_state = pad_with_terminal_state - def sample_from_model( - self, model: nn.Module, n: int, cond_info: Tensor, dev: torch.device, random_action_prob: float = 0.0 - ): + def sample_from_model(self, model: nn.Module, n: int, cond_info: Optional[Tensor], random_action_prob: float = 0.0): """Samples a model in a minibatch Parameters @@ -83,8 +79,6 @@ def sample_from_model( Number of graphs to sample cond_info: Tensor Conditional information of each trajectory, shape (n, n_info) - dev: torch.device - Device on which data is manipulated random_action_prob: float Probability of taking a random action at each step @@ -97,6 +91,7 @@ def sample_from_model( - bck_logprob: sum logprobs P_B - is_valid: is the generated graph valid according to the env & ctx """ + dev = get_worker_device() # This will be returned data = [{"traj": [], "reward_pred": None, "is_valid": True, "is_sink": []} for i in range(n)] # Let's also keep track of trajectory statistics according to the model @@ -112,6 +107,8 @@ def sample_from_model( # evaluated at s_{t+1} not s_t. bck_a = [[GraphAction(GraphActionType.Stop)] for _ in range(n)] + rng = get_worker_rng() + def not_done(lst): return [e for i, e in enumerate(lst) if not done[i]] @@ -122,11 +119,12 @@ def not_done(lst): # Forward pass to get GraphActionCategorical # Note about `*_`, the model may be outputting its own bck_cat, but we ignore it if it does. # TODO: compute bck_cat.log_prob(bck_a) when relevant - fwd_cat, *_, log_reward_preds = model(self.ctx.collate(torch_graphs).to(dev), cond_info[not_done_mask]) + ci = cond_info[not_done_mask] if cond_info is not None else None + fwd_cat, *_, log_reward_preds = model(self.ctx.collate(torch_graphs).to(dev), ci) if random_action_prob > 0: # Device which graphs in the minibatch will get their action randomized is_random_action = torch.tensor( - self.rng.uniform(size=len(torch_graphs)) < random_action_prob, device=dev + rng.uniform(size=len(torch_graphs)) < random_action_prob, device=dev ).float() # Set the logits to some large value to have a uniform distribution fwd_cat.logits = [ @@ -172,8 +170,7 @@ def not_done(lst): data[i]["is_sink"].append(0) graphs[i] = gp if done[i] and self.sanitize_samples and not self.ctx.is_sane(graphs[i]): - # check if the graph is sane (e.g. RDKit can - # construct a molecule from it) otherwise + # check if the graph is sane (e.g. 
RDKit can construct a molecule from it) otherwise # treat the done action as illegal data[i]["is_valid"] = False if all(done): @@ -212,8 +209,7 @@ def sample_backward_from_graphs( self, graphs: List[Graph], model: Optional[nn.Module], - cond_info: Tensor, - dev: torch.device, + cond_info: Optional[Tensor], random_action_prob: float = 0.0, ): """Sample a model's P_B starting from a list of graphs, or if the model is None, use a uniform distribution @@ -233,6 +229,7 @@ def sample_backward_from_graphs( Probability of taking a random action (only used if model parameterizes P_B) """ + dev = get_worker_device() n = len(graphs) done = [False] * n data = [ @@ -258,7 +255,8 @@ def not_done(lst): torch_graphs = [self.ctx.graph_to_Data(graphs[i]) for i in not_done(range(n))] not_done_mask = torch.tensor(done, device=dev).logical_not() if model is not None: - _, bck_cat, *_ = model(self.ctx.collate(torch_graphs).to(dev), cond_info[not_done_mask]) + ci = cond_info[not_done_mask] if cond_info is not None else None + _, bck_cat, *_ = model(self.ctx.collate(torch_graphs).to(dev), ci) else: gbatch = self.ctx.collate(torch_graphs) action_types = self.ctx.bck_action_type_order diff --git a/src/gflownet/algo/multiobjective_reinforce.py b/src/gflownet/algo/multiobjective_reinforce.py index aa1feef8..b1a636de 100644 --- a/src/gflownet/algo/multiobjective_reinforce.py +++ b/src/gflownet/algo/multiobjective_reinforce.py @@ -1,4 +1,3 @@ -import numpy as np import torch import torch_geometric.data as gd from torch_scatter import scatter @@ -17,10 +16,9 @@ def __init__( self, env: GraphBuildingEnv, ctx: GraphBuildingEnvContext, - rng: np.random.RandomState, cfg: Config, ): - super().__init__(env, ctx, rng, cfg) + super().__init__(env, ctx, cfg) def compute_batch_losses(self, model: TrajectoryBalanceModel, batch: gd.Batch, num_bootstrap: int = 0): """Compute multi objective REINFORCE loss over trajectories contained in the batch""" diff --git a/src/gflownet/algo/soft_q_learning.py b/src/gflownet/algo/soft_q_learning.py index dc205981..a9d61aaa 100644 --- a/src/gflownet/algo/soft_q_learning.py +++ b/src/gflownet/algo/soft_q_learning.py @@ -1,4 +1,3 @@ -import numpy as np import torch import torch.nn as nn import torch_geometric.data as gd @@ -16,7 +15,6 @@ def __init__( self, env: GraphBuildingEnv, ctx: GraphBuildingEnvContext, - rng: np.random.RandomState, cfg: Config, ): """Soft Q-Learning implementation, see @@ -32,14 +30,11 @@ def __init__( A graph environment. ctx: GraphBuildingEnvContext A context. 
- rng: np.random.RandomState - rng used to take random actions cfg: Config The experiment configuration """ self.ctx = ctx self.env = env - self.rng = rng self.max_len = cfg.algo.max_len self.max_nodes = cfg.algo.max_nodes self.illegal_action_logreward = cfg.algo.illegal_action_logreward @@ -50,7 +45,7 @@ def __init__( # Experimental flags self.sample_temp = 1 self.do_q_prime_correction = False - self.graph_sampler = GraphSampler(ctx, env, self.max_len, self.max_nodes, rng, self.sample_temp) + self.graph_sampler = GraphSampler(ctx, env, self.max_len, self.max_nodes, self.sample_temp) def create_training_data_from_own_samples( self, model: nn.Module, n: int, cond_info: Tensor, random_action_prob: float @@ -78,7 +73,7 @@ def create_training_data_from_own_samples( """ dev = get_worker_device() cond_info = cond_info.to(dev) - data = self.graph_sampler.sample_from_model(model, n, cond_info, dev, random_action_prob) + data = self.graph_sampler.sample_from_model(model, n, cond_info, random_action_prob) return data def create_training_data_from_graphs(self, graphs): @@ -150,7 +145,7 @@ def compute_batch_losses(self, model: nn.Module, batch: gd.Batch, num_bootstrap: # The position of the last graph of each trajectory final_graph_idx = torch.cumsum(batch.traj_lens, 0) - 1 - # Forward pass of the model, returns a GraphActionCategorical and per molecule predictions + # Forward pass of the model, returns a GraphActionCategorical and per object predictions # Here we will interpret the logits of the fwd_cat as Q values Q, per_state_preds = model(batch, cond_info[batch_idx]) diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index efa25d30..eac57cc6 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -94,9 +94,8 @@ def __init__( self, env: GraphBuildingEnv, ctx: GraphBuildingEnvContext, - rng: np.random.RandomState, cfg: Config, - ): + ) -> None: """Instanciate a TB algorithm. Parameters @@ -105,14 +104,11 @@ def __init__( A graph environment. ctx: GraphBuildingEnvContext A context. 
- rng: np.random.RandomState - rng used to take random actions cfg: Config Hyperparameters """ self.ctx = ctx self.env = env - self.rng = rng self.global_cfg = cfg self.cfg = cfg.algo.tb self.max_len = cfg.algo.max_len @@ -143,7 +139,6 @@ def __init__( env, cfg.algo.max_len, cfg.algo.max_nodes, - rng, self.sample_temp, correct_idempotent=self.cfg.do_correct_idempotent, pad_with_terminal_state=self.cfg.do_parameterize_p_b, @@ -156,7 +151,11 @@ def set_is_eval(self, is_eval: bool): self.is_eval = is_eval def create_training_data_from_own_samples( - self, model: TrajectoryBalanceModel, n: int, cond_info: Tensor, random_action_prob: float + self, + model: TrajectoryBalanceModel, + n: int, + cond_info: Optional[Tensor] = None, + random_action_prob: Optional[float] = 0.0, ): """Generate trajectories by sampling a model @@ -183,11 +182,12 @@ def create_training_data_from_own_samples( - is_valid: is the generated graph valid according to the env & ctx """ dev = get_worker_device() - cond_info = cond_info.to(dev) - data = self.graph_sampler.sample_from_model(model, n, cond_info, dev, random_action_prob) - logZ_pred = model.logZ(cond_info) - for i in range(n): - data[i]["logZ"] = logZ_pred[i].item() + cond_info = cond_info.to(dev) if cond_info is not None else None + data = self.graph_sampler.sample_from_model(model, n, cond_info, random_action_prob) + if cond_info is not None: + logZ_pred = model.logZ(cond_info) + for i in range(n): + data[i]["logZ"] = logZ_pred[i].item() return data def create_training_data_from_graphs( @@ -220,7 +220,7 @@ def create_training_data_from_graphs( dev = get_worker_device() cond_info = cond_info.to(dev) return self.graph_sampler.sample_backward_from_graphs( - graphs, model if self.cfg.do_parameterize_p_b else None, cond_info, dev, random_action_prob + graphs, model if self.cfg.do_parameterize_p_b else None, cond_info, random_action_prob ) trajs: List[Dict[str, Any]] = [{"traj": generate_forward_trajectory(i)} for i in graphs] for traj in trajs: @@ -351,6 +351,9 @@ def construct_batch(self, trajs, cond_info, log_rewards): batch.bck_ip_actions = torch.tensor(sum(bck_ipa, [])) batch.bck_ip_lens = torch.tensor([len(i) for i in bck_ipa]) + # compute_batch_losses expects these two optional values, if someone else doesn't fill them in, default to 0 + batch.num_offline = 0 + batch.num_online = 0 return batch def compute_batch_losses( @@ -379,7 +382,7 @@ def compute_batch_losses( clip_log_R = torch.maximum( log_rewards, torch.tensor(self.global_cfg.algo.illegal_action_logreward, device=dev) ).float() - cond_info = batch.cond_info + cond_info = getattr(batch, "cond_info", None) invalid_mask = 1 - batch.is_valid # This index says which trajectory each graph belongs to, so @@ -393,16 +396,18 @@ def compute_batch_losses( first_graph_idx = shift_right(traj_cumlen) final_graph_idx_1 = torch.maximum(final_graph_idx - 1, first_graph_idx) - fwd_cat: GraphActionCategorical + fwd_cat: GraphActionCategorical # The per-state cond_info + batched_cond_info = cond_info[batch_idx] if cond_info is not None else None + # Forward pass of the model, returns a GraphActionCategorical representing the forward # policy P_F, optionally a backward policy P_B, and per-graph outputs (e.g. F(s) in SubTB). 
if self.cfg.do_parameterize_p_b: - fwd_cat, bck_cat, per_graph_out = model(batch, cond_info[batch_idx]) + fwd_cat, bck_cat, per_graph_out = model(batch, batched_cond_info) else: if self.model_is_autoregressive: fwd_cat, per_graph_out = model(batch, cond_info, batched=True) else: - fwd_cat, per_graph_out = model(batch, cond_info[batch_idx]) + fwd_cat, per_graph_out = model(batch, batched_cond_info) # Retreive the reward predictions for the full graphs, # i.e. the final graph of each trajectory log_reward_preds = per_graph_out[final_graph_idx, 0] diff --git a/src/gflownet/data/__init__.py b/src/gflownet/data/__init__.py index be5fa80c..e69de29b 100644 --- a/src/gflownet/data/__init__.py +++ b/src/gflownet/data/__init__.py @@ -1,6 +0,0 @@ -from torch.utils.data import Dataset - - -class DatasetWithReward(Dataset): - def flat_reward_transform(self, r): - pass diff --git a/src/gflownet/data/data_source.py b/src/gflownet/data/data_source.py index 782f0750..85ede753 100644 --- a/src/gflownet/data/data_source.py +++ b/src/gflownet/data/data_source.py @@ -204,13 +204,13 @@ def iterator(): def call_sampling_hooks(self, trajs): batch_info = {} # TODO: just pass trajs to the hooks and deprecate passing all those arguments - flat_rewards = torch.stack([t["flat_rewards"] for t in trajs]) + obj_props = torch.stack([t["obj_props"] for t in trajs]) # convert cond_info back to a dict cond_info = {k: torch.stack([t["cond_info"][k] for t in trajs]) for k in trajs[0]["cond_info"]} log_rewards = torch.stack([t["log_reward"] for t in trajs]) rewards = torch.exp(log_rewards / (cond_info.get("beta", 1))) for hook in self.sampling_hooks: - batch_info.update(hook(trajs, rewards, flat_rewards, cond_info)) + batch_info.update(hook(trajs, rewards, obj_props, cond_info)) return batch_info def create_batch(self, trajs, batch_info): @@ -230,26 +230,26 @@ def create_batch(self, trajs, batch_info): log_ns = [self.ctx.traj_log_n(i["traj"]) for i in trajs] batch.log_n = torch.tensor([i[-1] for i in log_ns], dtype=torch.float32) batch.log_ns = torch.tensor(sum(log_ns, start=[]), dtype=torch.float32) - batch.flat_rewards = torch.stack([t["flat_rewards"] for t in trajs]) + batch.obj_props = torch.stack([t["obj_props"] for t in trajs]) return batch def compute_properties(self, trajs, mark_as_online=False): - """Sets trajs' flat_rewards and is_valid keys by querying the task.""" - # TODO: refactor flat_rewards into properties + """Sets trajs' obj_props and is_valid keys by querying the task.""" + # TODO: refactor obj_props into properties valid_idcs = torch.tensor([i for i in range(len(trajs)) if trajs[i].get("is_valid", True)]).long() # fetch the valid trajectories endpoints - objs = [self.ctx.graph_to_mol(trajs[i]["result"]) for i in valid_idcs] + objs = [self.ctx.graph_to_obj(trajs[i]["result"]) for i in valid_idcs] # ask the task to compute their reward - # TODO: it's really weird that the task is responsible for this and returns a flat_rewards + # TODO: it's really weird that the task is responsible for this and returns a obj_props # tensor whose first dimension is possibly not the same as the output??? 
- flat_rewards, m_is_valid = self.task.compute_flat_rewards(objs) - assert flat_rewards.ndim == 2, "FlatRewards should be (mbsize, n_objectives), even if n_objectives is 1" + obj_props, m_is_valid = self.task.compute_obj_properties(objs) + assert obj_props.ndim == 2, "FlatRewards should be (mbsize, n_objectives), even if n_objectives is 1" # The task may decide some of the objs are invalid, we have to again filter those valid_idcs = valid_idcs[m_is_valid] - all_fr = torch.zeros((len(trajs), flat_rewards.shape[1])) - all_fr[valid_idcs] = flat_rewards + all_fr = torch.zeros((len(trajs), obj_props.shape[1])) + all_fr[valid_idcs] = obj_props for i in range(len(trajs)): - trajs[i]["flat_rewards"] = all_fr[i] + trajs[i]["obj_props"] = all_fr[i] trajs[i]["is_online"] = mark_as_online # Override the is_valid key in case the task made some objs invalid for i in valid_idcs: @@ -257,9 +257,9 @@ def compute_properties(self, trajs, mark_as_online=False): def compute_log_rewards(self, trajs): """Sets trajs' log_reward key by querying the task.""" - flat_rewards = torch.stack([t["flat_rewards"] for t in trajs]) + obj_props = torch.stack([t["obj_props"] for t in trajs]) cond_info = {k: torch.stack([t["cond_info"][k] for t in trajs]) for k in trajs[0]["cond_info"]} - log_rewards = self.task.cond_info_to_logreward(cond_info, flat_rewards) + log_rewards = self.task.cond_info_to_logreward(cond_info, obj_props) min_r = torch.as_tensor(self.cfg.algo.illegal_action_logreward).float() for i in range(len(trajs)): trajs[i]["log_reward"] = log_rewards[i] if trajs[i].get("is_valid", True) else min_r @@ -267,7 +267,7 @@ def compute_log_rewards(self, trajs): def send_to_replay(self, trajs): if self.replay_buffer is not None: for t in trajs: - self.replay_buffer.push(t, t["log_reward"], t["flat_rewards"], t["cond_info"], t["is_valid"]) + self.replay_buffer.push(t, t["log_reward"], t["obj_props"], t["cond_info"], t["is_valid"]) def set_traj_cond_info(self, trajs, cond_info): for i in range(len(trajs)): @@ -275,7 +275,7 @@ def set_traj_cond_info(self, trajs, cond_info): def set_traj_props(self, trajs, props): for i in range(len(trajs)): - trajs[i]["flat_rewards"] = props[i] # TODO: refactor + trajs[i]["obj_props"] = props[i] # TODO: refactor def relabel_in_hindsight(self, trajs): if self.cfg.replay.hindsight_ratio == 0: @@ -286,10 +286,10 @@ def relabel_in_hindsight(self, trajs): # samples indexes of trajectories without repeats hindsight_idxs = torch.randperm(len(trajs))[: int(len(trajs) * self.cfg.replay.hindsight_ratio)] log_rewards = torch.stack([t["log_reward"] for t in trajs]) - flat_rewards = torch.stack([t["flat_rewards"] for t in trajs]) + obj_props = torch.stack([t["obj_props"] for t in trajs]) cond_info = {k: torch.stack([t["cond_info"][k] for t in trajs]) for k in trajs[0]["cond_info"]} cond_info, log_rewards = self.task.relabel_condinfo_and_logrewards( - cond_info, log_rewards, flat_rewards, hindsight_idxs + cond_info, log_rewards, obj_props, hindsight_idxs ) self.set_traj_cond_info(trajs, cond_info) for i in range(len(trajs)): diff --git a/src/gflownet/data/qm9.py b/src/gflownet/data/qm9.py index 8fd144c2..9ba71054 100644 --- a/src/gflownet/data/qm9.py +++ b/src/gflownet/data/qm9.py @@ -30,10 +30,10 @@ def __init__(self, h5_file=None, xyz_file=None, train=True, targets=["gap"], spl self.idcs = idcs[: int(np.floor(ratio * len(self.df)))] else: self.idcs = idcs[int(np.floor(ratio * len(self.df))) :] - self.mol_to_graph = lambda x: x + self.obj_to_graph = lambda x: x def setup(self, task, ctx): - 
self.mol_to_graph = ctx.mol_to_graph + self.obj_to_graph = ctx.obj_to_graph def get_stats(self, target=None, percentile=0.95): if target is None: @@ -46,7 +46,7 @@ def __len__(self): def __getitem__(self, idx): return ( - self.mol_to_graph(Chem.MolFromSmiles(self.df["SMILES"][self.idcs[idx]])), + self.obj_to_graph(Chem.MolFromSmiles(self.df["SMILES"][self.idcs[idx]])), torch.tensor([self.df[t][self.idcs[idx]] for t in self.targets]).float(), ) diff --git a/src/gflownet/data/replay_buffer.py b/src/gflownet/data/replay_buffer.py index df14cc68..7fc95024 100644 --- a/src/gflownet/data/replay_buffer.py +++ b/src/gflownet/data/replay_buffer.py @@ -4,10 +4,11 @@ import torch from gflownet.config import Config +from gflownet.utils.misc import get_worker_rng class ReplayBuffer(object): - def __init__(self, cfg: Config, rng: np.random.Generator = None): + def __init__(self, cfg: Config): """ Replay buffer for storing and sampling arbitrary data (e.g. transitions or trajectories) In self.push(), the buffer detaches any torch tensor and sends it to the CPU. @@ -18,7 +19,6 @@ def __init__(self, cfg: Config, rng: np.random.Generator = None): self.buffer: List[tuple] = [] self.position = 0 - self.rng = rng def push(self, *args): if len(self.buffer) == 0: @@ -32,7 +32,7 @@ def push(self, *args): self.position = (self.position + 1) % self.capacity def sample(self, batch_size): - idxs = self.rng.choice(len(self.buffer), batch_size) + idxs = get_worker_rng().choice(len(self.buffer), batch_size) out = list(zip(*[self.buffer[idx] for idx in idxs])) for i in range(len(out)): # stack if all elements are numpy arrays or torch tensors diff --git a/src/gflownet/envs/frag_mol_env.py b/src/gflownet/envs/frag_mol_env.py index 8317e7dc..bac09959 100644 --- a/src/gflownet/envs/frag_mol_env.py +++ b/src/gflownet/envs/frag_mol_env.py @@ -304,15 +304,15 @@ def collate(self, graphs: List[gd.Data]) -> gd.Batch: """ return gd.Batch.from_data_list(graphs, follow_batch=["edge_index"]) - def mol_to_graph(self, mol): + def obj_to_graph(self, mol): """Convert an RDMol to a Graph""" assert type(mol) is Chem.Mol all_matches = {} for fragidx, frag in self.sorted_frags: all_matches[fragidx] = mol.GetSubstructMatches(frag, uniquify=False) - return _recursive_decompose(self, mol, all_matches, {}, [], [], 9) + return _recursive_decompose(self, mol, all_matches, {}, [], [], self.max_frags) - def graph_to_mol(self, g: Graph) -> Chem.Mol: + def graph_to_obj(self, g: Graph) -> Chem.Mol: """Convert a Graph to an RDKit molecule Parameters @@ -361,7 +361,7 @@ def _pop_H(atom): def is_sane(self, g: Graph) -> bool: """Verifies whether the given Graph is valid according to RDKit""" try: - mol = self.graph_to_mol(g) + mol = self.graph_to_obj(g) assert Chem.MolFromSmiles(Chem.MolToSmiles(mol)) is not None except Exception: return False @@ -371,7 +371,7 @@ def is_sane(self, g: Graph) -> bool: def object_to_log_repr(self, g: Graph): """Convert a Graph to a string representation""" - return Chem.MolToSmiles(self.graph_to_mol(g)) + return Chem.MolToSmiles(self.graph_to_obj(g)) def has_n(self) -> bool: return True @@ -477,7 +477,7 @@ def _recursive_decompose(ctx, m, all_matches, a2f, frags, bonds, max_depth=9, nu for a, b, stemidx_a, stemidx_b, _, _ in bonds: g.edges[(a, b)]["src_attach"] = stemidx_a # TODO: verify src/dst is correct? 
g.edges[(a, b)]["dst_attach"] = stemidx_b - m2 = ctx.graph_to_mol(g) + m2 = ctx.graph_to_obj(g) if m2.HasSubstructMatch(m) and m.HasSubstructMatch(m2): return g return None diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index 3279058d..b10b228d 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -11,7 +11,6 @@ import torch import torch_geometric.data as gd from networkx.algorithms.isomorphism import is_isomorphic -from rdkit.Chem import Mol from torch_scatter import scatter, scatter_max @@ -902,6 +901,7 @@ class GraphBuildingEnvContext: device: torch.device action_type_order: List[GraphActionType] + bck_action_type_order: List[GraphActionType] def ActionIndex_to_GraphAction(self, g: gd.Data, aidx: ActionIndex, fwd: bool = True) -> GraphAction: """Translate an action index (e.g. from a GraphActionCategorical) to a GraphAction @@ -986,19 +986,18 @@ def is_sane(self, g: Graph) -> bool: """ raise NotImplementedError() - def mol_to_graph(self, mol: Mol) -> Graph: - """Verifies whether a graph is sane according to the context. This can - catch, e.g. impossible molecules. + def obj_to_graph(self, obj: Any) -> Graph: + """Converts a native object into a generic Graph that the environment can handle Parameters ---------- - mol: Mol - An RDKit molecule + obj: Any + An object Returns ------- g: Graph - The corresponding Graph representation of that molecule. + The corresponding Graph representation of that object. """ raise NotImplementedError() @@ -1008,6 +1007,10 @@ def object_to_log_repr(self, g: Graph) -> str: [[(i, g.nodes[i]) for i in g.nodes], [(e, g.edges[e]) for e in g.edges]], separators=(",", ":") ) + def graph_to_obj(self, g: Graph) -> Any: + """Convert a graph back to an object""" + raise NotImplementedError() + def has_n(self) -> bool: return False diff --git a/src/gflownet/envs/mol_building_env.py b/src/gflownet/envs/mol_building_env.py index 6dc5734c..e24ff8f2 100644 --- a/src/gflownet/envs/mol_building_env.py +++ b/src/gflownet/envs/mol_building_env.py @@ -396,7 +396,7 @@ def collate(self, graphs: List[gd.Data]): """Batch Data instances""" return gd.Batch.from_data_list(graphs, follow_batch=["edge_index", "non_edge_index"]) - def mol_to_graph(self, mol: Mol) -> Graph: + def obj_to_graph(self, mol: Mol) -> Graph: """Convert an RDMol to a Graph""" g = Graph() mol = Mol(mol) # Make a copy @@ -427,7 +427,7 @@ def mol_to_graph(self, mol: Mol) -> Graph: ) return g - def graph_to_mol(self, g: Graph) -> Mol: + def graph_to_obj(self, g: Graph) -> Mol: mp = Chem.RWMol() mp.BeginBatchEdit() for i in range(len(g.nodes)): @@ -455,7 +455,7 @@ def graph_to_mol(self, g: Graph) -> Mol: def is_sane(self, g: Graph) -> bool: try: - mol = self.graph_to_mol(g) + mol = self.graph_to_obj(g) except Exception: return False if mol is None: @@ -465,7 +465,7 @@ def is_sane(self, g: Graph) -> bool: def object_to_log_repr(self, g: Graph): """Convert a Graph to a string representation""" try: - mol = self.graph_to_mol(g) + mol = self.graph_to_obj(g) assert mol is not None return Chem.MolToSmiles(mol) except Exception: diff --git a/src/gflownet/envs/seq_building_env.py b/src/gflownet/envs/seq_building_env.py index 25c48bad..0e0281a9 100644 --- a/src/gflownet/envs/seq_building_env.py +++ b/src/gflownet/envs/seq_building_env.py @@ -125,9 +125,9 @@ def collate(self, graphs: List[Data]): def is_sane(self, g: Graph) -> bool: return True - def graph_to_mol(self, g: Graph): + def graph_to_obj(self, g: Graph): s: Seq = g # 
type: ignore return "".join(self.alphabet[int(i)] for i in s.seq) def object_to_log_repr(self, g: Graph): - return self.graph_to_mol(g) + return self.graph_to_obj(g) diff --git a/src/gflownet/envs/test.py b/src/gflownet/envs/test.py index d9c4da4b..9ac5d46b 100644 --- a/src/gflownet/envs/test.py +++ b/src/gflownet/envs/test.py @@ -92,7 +92,7 @@ def main(smi, n_steps): model = Model(ctx, num_emb=64) opt = torch.optim.Adam(model.parameters(), 5e-4) mol = Chem.MolFromSmiles(smi) - molg = ctx.mol_to_graph(mol) + molg = ctx.obj_to_graph(mol) traj = generate_forward_trajectory(molg) for g, a in traj: print(a.action, a.source, a.target, a.value) @@ -140,7 +140,7 @@ def main(smi, n_steps): if not issub: raise ValueError() print(g) - new_mol = ctx.graph_to_mol(g) + new_mol = ctx.graph_to_obj(g) print(Chem.MolToSmiles(new_mol)) # This should be True as well print(new_mol.HasSubstructMatch(mol) and mol.HasSubstructMatch(new_mol)) diff --git a/src/gflownet/models/bengio2021flow.py b/src/gflownet/models/bengio2021flow.py index ae71d74d..2fde211b 100644 --- a/src/gflownet/models/bengio2021flow.py +++ b/src/gflownet/models/bengio2021flow.py @@ -31,103 +31,103 @@ # These are the fragments used in the original paper, each fragment is a tuple # (SMILES string, attachment atom idx). # The attachment atom idx is where bonds between fragments are legal. -FRAGMENTS = [ - ["Br", [0]], - ["C", [0]], - ["C#N", [0]], - ["C1=CCCCC1", [0, 2, 3]], - ["C1=CNC=CC1", [0, 2]], - ["C1CC1", [0]], - ["C1CCCC1", [0]], - ["C1CCCCC1", [0, 1, 2, 3, 4, 5]], - ["C1CCNC1", [0, 2, 3, 4]], - ["C1CCNCC1", [0, 1, 3]], - ["C1CCOC1", [0, 1, 2, 4]], - ["C1CCOCC1", [0, 1, 2, 4, 5]], - ["C1CNCCN1", [2, 5]], - ["C1COCCN1", [5]], - ["C1COCC[NH2+]1", [5]], - ["C=C", [0, 1]], - ["C=C(C)C", [0]], - ["C=CC", [0, 1]], - ["C=N", [0]], - ["C=O", [0]], - ["CC", [0, 1]], - ["CC(C)C", [1]], - ["CC(C)O", [1]], - ["CC(N)=O", [2]], - ["CC=O", [1]], - ["CCC", [1]], - ["CCO", [1]], - ["CN", [0, 1]], - ["CNC", [1]], - ["CNC(C)=O", [0]], - ["CNC=O", [0, 2]], - ["CO", [0, 1]], - ["CS", [0]], - ["C[NH3+]", [0]], - ["C[SH2+]", [1]], - ["Cl", [0]], - ["F", [0]], - ["FC(F)F", [1]], - ["I", [0]], - ["N", [0]], - ["N=CN", [1]], - ["NC=O", [0, 1]], - ["N[SH](=O)=O", [1]], - ["O", [0]], - ["O=CNO", [1]], - ["O=CO", [1]], - ["O=C[O-]", [1]], - ["O=PO", [1]], - ["O=P[O-]", [1]], - ["O=S=O", [1]], - ["O=[NH+][O-]", [1]], - ["O=[PH](O)O", [1]], - ["O=[PH]([O-])O", [1]], - ["O=[SH](=O)O", [1]], - ["O=[SH](=O)[O-]", [1]], - ["O=c1[nH]cnc2[nH]cnc12", [3, 6]], - ["O=c1[nH]cnc2c1NCCN2", [8, 3]], - ["O=c1cc[nH]c(=O)[nH]1", [2, 4]], - ["O=c1nc2[nH]c3ccccc3nc-2c(=O)[nH]1", [8, 4, 7]], - ["O=c1nccc[nH]1", [3, 6]], - ["S", [0]], - ["c1cc[nH+]cc1", [1, 3]], - ["c1cc[nH]c1", [0, 2]], - ["c1ccc2[nH]ccc2c1", [6]], - ["c1ccc2ccccc2c1", [0, 2]], - ["c1ccccc1", [0, 1, 2, 3, 4, 5]], - ["c1ccncc1", [0, 1, 2, 4, 5]], - ["c1ccsc1", [2, 4]], - ["c1cn[nH]c1", [0, 1, 3, 4]], - ["c1cncnc1", [0, 1, 3, 5]], - ["c1cscn1", [0, 3]], - ["c1ncc2nc[nH]c2n1", [2, 6]], +FRAGMENTS: list[tuple[str, list[int]]] = [ + ("Br", [0]), + ("C", [0]), + ("C#N", [0]), + ("C1=CCCCC1", [0, 2, 3]), + ("C1=CNC=CC1", [0, 2]), + ("C1CC1", [0]), + ("C1CCCC1", [0]), + ("C1CCCCC1", [0, 1, 2, 3, 4, 5]), + ("C1CCNC1", [0, 2, 3, 4]), + ("C1CCNCC1", [0, 1, 3]), + ("C1CCOC1", [0, 1, 2, 4]), + ("C1CCOCC1", [0, 1, 2, 4, 5]), + ("C1CNCCN1", [2, 5]), + ("C1COCCN1", [5]), + ("C1COCC[NH2+]1", [5]), + ("C=C", [0, 1]), + ("C=C(C)C", [0]), + ("C=CC", [0, 1]), + ("C=N", [0]), + ("C=O", [0]), + ("CC", [0, 1]), + ("CC(C)C", [1]), + ("CC(C)O", 
[1]), + ("CC(N)=O", [2]), + ("CC=O", [1]), + ("CCC", [1]), + ("CCO", [1]), + ("CN", [0, 1]), + ("CNC", [1]), + ("CNC(C)=O", [0]), + ("CNC=O", [0, 2]), + ("CO", [0, 1]), + ("CS", [0]), + ("C[NH3+]", [0]), + ("C[SH2+]", [1]), + ("Cl", [0]), + ("F", [0]), + ("FC(F)F", [1]), + ("I", [0]), + ("N", [0]), + ("N=CN", [1]), + ("NC=O", [0, 1]), + ("N[SH](=O)=O", [1]), + ("O", [0]), + ("O=CNO", [1]), + ("O=CO", [1]), + ("O=C[O-]", [1]), + ("O=PO", [1]), + ("O=P[O-]", [1]), + ("O=S=O", [1]), + ("O=[NH+][O-]", [1]), + ("O=[PH](O)O", [1]), + ("O=[PH]([O-])O", [1]), + ("O=[SH](=O)O", [1]), + ("O=[SH](=O)[O-]", [1]), + ("O=c1[nH]cnc2[nH]cnc12", [3, 6]), + ("O=c1[nH]cnc2c1NCCN2", [8, 3]), + ("O=c1cc[nH]c(=O)[nH]1", [2, 4]), + ("O=c1nc2[nH]c3ccccc3nc-2c(=O)[nH]1", [8, 4, 7]), + ("O=c1nccc[nH]1", [3, 6]), + ("S", [0]), + ("c1cc[nH+]cc1", [1, 3]), + ("c1cc[nH]c1", [0, 2]), + ("c1ccc2[nH]ccc2c1", [6]), + ("c1ccc2ccccc2c1", [0, 2]), + ("c1ccccc1", [0, 1, 2, 3, 4, 5]), + ("c1ccncc1", [0, 1, 2, 4, 5]), + ("c1ccsc1", [2, 4]), + ("c1cn[nH]c1", [0, 1, 3, 4]), + ("c1cncnc1", [0, 1, 3, 5]), + ("c1cscn1", [0, 3]), + ("c1ncc2nc[nH]c2n1", [2, 6]), ] # 18 fragments from "Towards Understanding and Improving GFlowNet Training" # by Shen et al. (https://arxiv.org/abs/2305.07170) -FRAGMENTS_18 = [ - ["CO", [1, 0]], - ["O=c1[nH]cnc2[nH]cnc12", [3, 6]], - ["S", [0, 0]], - ["C1CNCCN1", [2, 5]], - ["c1cc[nH+]cc1", [3, 1]], - ["c1ccccc1", [0, 2]], - ["C1CCCCC1", [0, 2]], - ["CC(C)C", [1, 2]], - ["C1CCOCC1", [0, 2]], - ["c1cn[nH]c1", [4, 0]], - ["C1CCNC1", [2, 0]], - ["c1cncnc1", [0, 1]], - ["O=c1nc2[nH]c3ccccc3nc-2c(=O)[nH]1", [8, 4]], - ["c1ccncc1", [1, 0]], - ["O=c1nccc[nH]1", [6, 3]], - ["O=c1cc[nH]c(=O)[nH]1", [2, 4]], - ["C1CCOC1", [2, 4]], - ["C1CCNCC1", [1, 0]], +FRAGMENTS_18: list[tuple[str, list[int]]] = [ + ("CO", [1, 0]), + ("O=c1[nH]cnc2[nH]cnc12", [3, 6]), + ("S", [0, 0]), + ("C1CNCCN1", [2, 5]), + ("c1cc[nH+]cc1", [3, 1]), + ("c1ccccc1", [0, 2]), + ("C1CCCCC1", [0, 2]), + ("CC(C)C", [1, 2]), + ("C1CCOCC1", [0, 2]), + ("c1cn[nH]c1", [4, 0]), + ("C1CCNC1", [2, 0]), + ("c1cncnc1", [0, 1]), + ("O=c1nc2[nH]c3ccccc3nc-2c(=O)[nH]1", [8, 4]), + ("c1ccncc1", [1, 0]), + ("O=c1nccc[nH]1", [6, 3]), + ("O=c1cc[nH]c(=O)[nH]1", [2, 4]), + ("C1CCOC1", [2, 4]), + ("C1CCNCC1", [1, 0]), ] diff --git a/src/gflownet/models/config.py b/src/gflownet/models/config.py index 329c74c5..912ff9f8 100644 --- a/src/gflownet/models/config.py +++ b/src/gflownet/models/config.py @@ -9,6 +9,7 @@ class GraphTransformerConfig(StrictDataClass): num_heads: int = 2 ln_type: str = "pre" num_mlp_layers: int = 0 + concat_heads: bool = True class SeqPosEnc(int, Enum): diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index d0e29e72..366e4390 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -1,5 +1,5 @@ from itertools import chain -from typing import Dict +from typing import Dict, Optional import torch import torch.nn as nn @@ -33,7 +33,9 @@ class GraphTransformer(nn.Module): node embeddings, and of the final virtual node embeddings. 
""" - def __init__(self, x_dim, e_dim, g_dim, num_emb=64, num_layers=3, num_heads=2, num_noise=0, ln_type="pre"): + def __init__( + self, x_dim, e_dim, g_dim, num_emb=64, num_layers=3, num_heads=2, num_noise=0, ln_type="pre", concat=True + ): """ Parameters ---------- @@ -55,6 +57,10 @@ def __init__(self, x_dim, e_dim, g_dim, num_emb=64, num_layers=3, num_heads=2, n ln_type: str The location of Layer Norm in the transformer, either 'pre' or 'post', default 'pre'. (apparently, before is better than after, see https://arxiv.org/pdf/2002.04745.pdf) + concat: bool + Whether each head uses num_emb units (True) or num_emb // num_heads (False) units. Defaults to True. + If True this implies num_emb * num_heads output units within the attention mechanism (which are later + reprojected to num_emb units). """ super().__init__() self.num_layers = num_layers @@ -64,14 +70,15 @@ def __init__(self, x_dim, e_dim, g_dim, num_emb=64, num_layers=3, num_heads=2, n self.x2h = mlp(x_dim + num_noise, num_emb, num_emb, 2) self.e2h = mlp(e_dim, num_emb, num_emb, 2) - self.c2h = mlp(g_dim, num_emb, num_emb, 2) + self.c2h = mlp(max(1, g_dim), num_emb, num_emb, 2) + n_att = num_emb * num_heads if concat else num_emb self.graph2emb = nn.ModuleList( sum( [ [ gnn.GENConv(num_emb, num_emb, num_layers=1, aggr="add", norm=None), - gnn.TransformerConv(num_emb * 2, num_emb, edge_dim=num_emb, heads=num_heads), - nn.Linear(num_heads * num_emb, num_emb), + gnn.TransformerConv(num_emb * 2, n_att // num_heads, edge_dim=num_emb, heads=num_heads), + nn.Linear(n_att, num_emb), gnn.LayerNorm(num_emb, affine=False), mlp(num_emb, num_emb * 4, num_emb, 1), gnn.LayerNorm(num_emb, affine=False), @@ -83,7 +90,7 @@ def __init__(self, x_dim, e_dim, g_dim, num_emb=64, num_layers=3, num_heads=2, n ) ) - def forward(self, g: gd.Batch, cond: torch.Tensor): + def forward(self, g: gd.Batch, cond: Optional[torch.Tensor]): """Forward pass Parameters @@ -105,7 +112,7 @@ def forward(self, g: gd.Batch, cond: torch.Tensor): x = g.x o = self.x2h(x) e = self.e2h(g.edge_attr) - c = self.c2h(cond) + c = self.c2h(cond if cond is not None else torch.ones((g.num_graphs, 1), device=g.x.device)) num_total_nodes = g.x.shape[0] # Augment the edges with a new edge to the conditioning # information node. 
This new node is connected to every node @@ -191,6 +198,7 @@ def __init__( num_layers=cfg.model.num_layers, num_heads=cfg.model.graph_transformer.num_heads, ln_type=cfg.model.graph_transformer.ln_type, + concat=cfg.model.graph_transformer.concat_heads, ) self.env_ctx = env_ctx num_emb = cfg.model.num_emb @@ -231,7 +239,12 @@ def __init__( self.emb2graph_out = mlp(num_glob_final, num_emb, num_graph_out, cfg.model.graph_transformer.num_mlp_layers) # TODO: flag for this - self.logZ = mlp(env_ctx.num_cond_dim, num_emb * 2, 1, 2) + self._logZ = mlp(max(1, env_ctx.num_cond_dim), num_emb * 2, 1, 2) + + def logZ(self, cond_info: Optional[torch.Tensor]): + if cond_info is None: + return self._logZ(torch.ones((1, 1), device=self._logZ[0].weight.device)) + return self._logZ(cond_info) def _make_cat(self, g: gd.Batch, emb: Dict[str, Tensor], action_types: list[GraphActionType]): return GraphActionCategorical( @@ -242,7 +255,7 @@ def _make_cat(self, g: gd.Batch, emb: Dict[str, Tensor], action_types: list[Grap types=action_types, ) - def forward(self, g: gd.Batch, cond: torch.Tensor): + def forward(self, g: gd.Batch, cond: Optional[torch.Tensor]): node_embeddings, graph_embeddings = self.transf(g, cond) # "Non-edges" are edges not currently in the graph that we could add if hasattr(g, "non_edge_index"): diff --git a/src/gflownet/models/seq_transformer.py b/src/gflownet/models/seq_transformer.py index b1a4173a..54557922 100644 --- a/src/gflownet/models/seq_transformer.py +++ b/src/gflownet/models/seq_transformer.py @@ -1,5 +1,6 @@ # This code is adapted from https://github.com/MJ10/mo_gfn import math +from typing import Optional import torch import torch.nn as nn @@ -51,7 +52,7 @@ def __init__( self.embedding = nn.Embedding(env_ctx.num_tokens, num_hid) encoder_layers = nn.TransformerEncoderLayer(num_hid, mc.seq_transformer.num_heads, num_hid, dropout=mc.dropout) self.encoder = nn.TransformerEncoder(encoder_layers, mc.num_layers) - self.logZ = nn.Linear(env_ctx.num_cond_dim, 1) + self._logZ = nn.Linear(env_ctx.num_cond_dim, 1) if self.use_cond: self.output = MLPWithDropout(num_hid + num_hid, num_outs, [4 * num_hid, 4 * num_hid], mc.dropout) self.cond_embed = nn.Linear(env_ctx.num_cond_dim, num_hid) @@ -59,6 +60,11 @@ def __init__( self.output = MLPWithDropout(num_hid, num_outs, [2 * num_hid, 2 * num_hid], mc.dropout) self.num_hid = num_hid + def logZ(self, cond_info: Optional[torch.Tensor]): + if cond_info is None: + return self._logZ(torch.ones((1, 1), device=self._logZ.weight.device)) + return self._logZ(cond_info) + def forward(self, xs: SeqBatch, cond, batched=False): """Returns a GraphActionCategorical and a tensor of state predictions. 
diff --git a/src/gflownet/online_trainer.py b/src/gflownet/online_trainer.py index 3ba1fde1..103acc95 100644 --- a/src/gflownet/online_trainer.py +++ b/src/gflownet/online_trainer.py @@ -46,7 +46,7 @@ def setup_algo(self): algo = SoftQLearning else: raise ValueError(algo) - self.algo = algo(self.env, self.ctx, self.rng, self.cfg) + self.algo = algo(self.env, self.ctx, self.cfg) def setup_data(self): self.training_data = [] @@ -71,13 +71,13 @@ def _opt(self, params, lr=None, momentum=None): def setup(self): super().setup() self.offline_ratio = 0 - self.replay_buffer = ReplayBuffer(self.cfg, self.rng) if self.cfg.replay.use else None + self.replay_buffer = ReplayBuffer(self.cfg) if self.cfg.replay.use else None self.sampling_hooks.append(AvgRewardHook()) self.valid_sampling_hooks.append(AvgRewardHook()) # Separate Z parameters from non-Z to allow for LR decay on the former - if hasattr(self.model, "logZ"): - Z_params = list(self.model.logZ.parameters()) + if hasattr(self.model, "_logZ"): + Z_params = list(self.model._logZ.parameters()) non_Z_params = [i for i in self.model.parameters() if all(id(i) != id(j) for j in Z_params)] else: Z_params = [] @@ -103,8 +103,10 @@ def setup(self): }[self.cfg.opt.clip_grad_type] # saving hyperparameters - git_hash = git.Repo(__file__, search_parent_directories=True).head.object.hexsha[:7] - self.cfg.git_hash = git_hash + try: + self.cfg.git_hash = git.Repo(__file__, search_parent_directories=True).head.object.hexsha[:7] + except git.InvalidGitRepositoryError: + self.cfg.git_hash = "unknown" # May not have been installed through git yaml_cfg = OmegaConf.to_yaml(self.cfg) if self.print_config: @@ -133,5 +135,5 @@ def step(self, loss: Tensor): class AvgRewardHook: - def __call__(self, trajs, rewards, flat_rewards, extra_info): + def __call__(self, trajs, rewards, obj_props, extra_info): return {"sampled_reward_avg": rewards.mean().item()} diff --git a/src/gflownet/tasks/make_rings.py b/src/gflownet/tasks/make_rings.py index 34f47924..3b87d69c 100644 --- a/src/gflownet/tasks/make_rings.py +++ b/src/gflownet/tasks/make_rings.py @@ -1,13 +1,12 @@ import socket -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Tuple -import numpy as np import torch from rdkit import Chem from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor -from gflownet import FlatRewards, GFNTask, RewardScalar +from gflownet import GFNTask, LogScalar, ObjectProperties from gflownet.config import Config, init_empty from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.online_trainer import StandardOnlineTrainer @@ -16,25 +15,16 @@ class MakeRingsTask(GFNTask): """A toy task where the reward is the number of rings in the molecule.""" - def __init__( - self, - rng: np.random.Generator, - ): - self.rng = rng - - def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: - return FlatRewards(y) - def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: return {"beta": torch.ones(n), "encoding": torch.ones(n, 1)} - def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: - scalar_logreward = torch.as_tensor(flat_reward).squeeze().clamp(min=1e-30).log() - return RewardScalar(scalar_logreward.flatten()) + def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], obj_props: ObjectProperties) -> LogScalar: + scalar_logreward = torch.as_tensor(obj_props).squeeze().clamp(min=1e-30).log() + return LogScalar(scalar_logreward.flatten()) - def 
compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: + def compute_obj_properties(self, mols: List[RDMol]) -> Tuple[ObjectProperties, Tensor]: rs = torch.tensor([m.GetRingInfo().NumRings() for m in mols]).float() - return FlatRewards(rs.reshape((-1, 1))), torch.ones(len(mols)).bool() + return ObjectProperties(rs.reshape((-1, 1))), torch.ones(len(mols)).bool() class MakeRingsTrainer(StandardOnlineTrainer): @@ -56,7 +46,7 @@ def set_default_hps(self, cfg: Config): cfg.replay.use = False def setup_task(self): - self.task = MakeRingsTask(rng=self.rng) + self.task = MakeRingsTask() def setup_env_context(self): self.ctx = MolBuildingEnvContext( diff --git a/src/gflownet/tasks/qm9.py b/src/gflownet/tasks/qm9.py index 5f489938..750d879e 100644 --- a/src/gflownet/tasks/qm9.py +++ b/src/gflownet/tasks/qm9.py @@ -9,7 +9,7 @@ from torch.utils.data import Dataset import gflownet.models.mxmnet as mxmnet -from gflownet import FlatRewards, GFNTask, RewardScalar +from gflownet import GFNTask, LogScalar, ObjectProperties from gflownet.config import Config, init_empty from gflownet.data.qm9 import QM9Dataset from gflownet.envs.mol_building_env import MolBuildingEnvContext @@ -26,22 +26,20 @@ def __init__( self, dataset: Dataset, cfg: Config, - rng: np.random.Generator = None, wrap_model: Callable[[nn.Module], nn.Module] = None, ): self._wrap_model = wrap_model - self.rng = rng self.device = get_worker_device() self.models = self.load_task_models(cfg.task.qm9.model_path) self.dataset = dataset - self.temperature_conditional = TemperatureConditional(cfg, rng) + self.temperature_conditional = TemperatureConditional(cfg) self.num_cond_dim = self.temperature_conditional.encoding_size() # TODO: fix interface self._min, self._max, self._percentile_95 = self.dataset.get_stats("gap", percentile=0.05) # type: ignore self._width = self._max - self._min self._rtrans = "unit+95p" # TODO: hyperparameter - def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: + def reward_transform(self, y: Union[float, Tensor]) -> ObjectProperties: """Transforms a target quantity y (e.g. the LUMO energy in QM9) to a positive reward scalar""" if self._rtrans == "exp": flat_r = np.exp(-(y - self._min) / self._width) @@ -52,9 +50,9 @@ def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: flat_r = 1 - (y - self._percentile_95) / self._width else: raise ValueError(self._rtrans) - return FlatRewards(flat_r) + return ObjectProperties(flat_r) - def inverse_flat_reward_transform(self, rp): + def inverse_reward_transform(self, rp): if self._rtrans == "exp": return -np.log(rp) * self._width + self._min elif self._rtrans == "unit": @@ -66,7 +64,7 @@ def load_task_models(self, path): gap_model = mxmnet.MXMNet(mxmnet.Config(128, 6, 5.0)) # TODO: this path should be part of the config? 
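The `make_rings.py` changes shown a little earlier in this patch illustrate the renamed task interface in its simplest form: no RNG in the constructor, `compute_obj_properties` instead of `compute_flat_rewards`, and `ObjectProperties`/`LogScalar` instead of `FlatRewards`/`RewardScalar`. A hedged usage sketch follows (the import path is assumed from the file layout; requires RDKit and this package installed; the SMILES are arbitrary):

from rdkit import Chem
from gflownet.tasks.make_rings import MakeRingsTask

task = MakeRingsTask()  # no rng argument any more
mols = [Chem.MolFromSmiles(s) for s in ["C1CC1", "c1ccccc1", "CCO"]]
obj_props, is_valid = task.compute_obj_properties(mols)     # ObjectProperties of shape (3, 1)
cond_info = task.sample_conditional_information(n=3, train_it=0)
log_r = task.cond_info_to_logreward(cond_info, obj_props)   # LogScalar of shape (3,)
print(obj_props.flatten().tolist(), log_r)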
try: - state_dict = torch.load(path) + state_dict = torch.load(path, map_location=self.device) except Exception as e: print( "Could not load model.", @@ -82,8 +80,8 @@ def load_task_models(self, path): def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: return self.temperature_conditional.sample(n) - def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: - return RewardScalar(self.temperature_conditional.transform(cond_info, to_logreward(flat_reward))) + def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: ObjectProperties) -> LogScalar: + return LogScalar(self.temperature_conditional.transform(cond_info, to_logreward(flat_reward))) def compute_reward_from_graph(self, graphs: List[gd.Data]) -> Tensor: batch = gd.Batch.from_data_list([i for i in graphs if i is not None]) @@ -93,7 +91,7 @@ def compute_reward_from_graph(self, graphs: List[gd.Data]) -> Tensor: preds = self.models["mxmnet_gap"](batch).reshape((-1,)).data.cpu() / mxmnet.HAR2EV # type: ignore[attr-defined] preds[preds.isnan()] = 1 preds = ( - self.flat_reward_transform(preds) + self.reward_transform(preds) .clip(1e-4, 2) .reshape( -1, @@ -101,15 +99,15 @@ def compute_reward_from_graph(self, graphs: List[gd.Data]) -> Tensor: ) return preds - def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: + def compute_obj_properties(self, mols: List[RDMol]) -> Tuple[ObjectProperties, Tensor]: graphs = [mxmnet.mol2graph(i) for i in mols] # type: ignore[attr-defined] is_valid = torch.tensor([i is not None for i in graphs]).bool() if not is_valid.any(): - return FlatRewards(torch.zeros((0, 1))), is_valid + return ObjectProperties(torch.zeros((0, 1))), is_valid preds = self.compute_reward_from_graph(graphs).reshape((-1, 1)) assert len(preds) == is_valid.sum() - return FlatRewards(preds), is_valid + return ObjectProperties(preds), is_valid class QM9GapTrainer(StandardOnlineTrainer): @@ -159,7 +157,6 @@ def setup_task(self): self.task = QM9GapTask( dataset=self.training_data, cfg=self.cfg, - rng=self.rng, wrap_model=self._wrap_for_mp, ) diff --git a/src/gflownet/tasks/qm9_moo.py b/src/gflownet/tasks/qm9_moo.py index 51029e3a..5f97e2c8 100644 --- a/src/gflownet/tasks/qm9_moo.py +++ b/src/gflownet/tasks/qm9_moo.py @@ -10,7 +10,7 @@ from torch.utils.data import DataLoader, Dataset import gflownet.models.mxmnet as mxmnet -from gflownet import FlatRewards, RewardScalar +from gflownet import LogScalar, ObjectProperties from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce from gflownet.config import Config @@ -40,17 +40,16 @@ def __init__( self, dataset: Dataset, cfg: Config, - rng: np.random.Generator = None, wrap_model: Callable[[nn.Module], nn.Module] = None, ): - super().__init__(dataset, cfg, rng, wrap_model) + super().__init__(dataset, cfg, wrap_model) self.cfg = cfg mcfg = self.cfg.task.qm9_moo self.objectives = cfg.task.qm9_moo.objectives cfg.cond.moo.num_objectives = len(self.objectives) self.dataset = dataset if self.cfg.cond.focus_region.focus_type is not None: - self.focus_cond = FocusRegionConditional(self.cfg, mcfg.n_valid, rng) + self.focus_cond = FocusRegionConditional(self.cfg, mcfg.n_valid) else: self.focus_cond = None self.pref_cond = MultiObjectiveWeightedPreferences(self.cfg) @@ -64,10 +63,10 @@ def __init__( ) assert set(self.objectives) <= {"gap", "qed", "sa", "mw"} and 
len(self.objectives) == len(set(self.objectives)) - def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: - return FlatRewards(torch.as_tensor(y)) + def reward_transform(self, y: Union[float, Tensor]) -> ObjectProperties: + return ObjectProperties(torch.as_tensor(y)) - def inverse_flat_reward_transform(self, rp): + def inverse_reward_transform(self, rp): return rp def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: @@ -122,7 +121,7 @@ def encode_conditional_information(self, steer_info: Tensor) -> Dict[str, Tensor } def relabel_condinfo_and_logrewards( - self, cond_info: Dict[str, Tensor], log_rewards: Tensor, flat_rewards: FlatRewards, hindsight_idxs: Tensor + self, cond_info: Dict[str, Tensor], log_rewards: Tensor, flat_rewards: ObjectProperties, hindsight_idxs: Tensor ): # TODO: we seem to be relabeling tensors in place, could that cause a problem? if self.focus_cond is None: @@ -148,7 +147,7 @@ def relabel_condinfo_and_logrewards( log_rewards = self.cond_info_to_logreward(cond_info, flat_rewards) return cond_info, log_rewards - def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: + def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: ObjectProperties) -> LogScalar: """ Compute the logreward from the flat_reward and the conditional information """ @@ -161,23 +160,23 @@ def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: Flat scalarized_rewards = self.pref_cond.transform(cond_info, flat_reward) scalarized_logrewards = to_logreward(scalarized_rewards) focused_logreward = ( - self.focus_cond.transform(cond_info, flat_reward, scalarized_logrewards) + self.focus_cond.transform(cond_info, (flat_reward, scalarized_logrewards)) if self.focus_cond is not None else scalarized_logrewards ) tempered_logreward = self.temperature_conditional.transform(cond_info, focused_logreward) clamped_logreward = tempered_logreward.clamp(min=self.cfg.algo.illegal_action_logreward) - return RewardScalar(clamped_logreward) + return LogScalar(clamped_logreward) - def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: + def compute_obj_properties(self, mols: List[RDMol]) -> Tuple[ObjectProperties, Tensor]: graphs = [mxmnet.mol2graph(i) for i in mols] # type: ignore[attr-defined] assert len(graphs) == len(mols) is_valid = [i is not None for i in graphs] is_valid_t = torch.tensor(is_valid, dtype=torch.bool) if not any(is_valid): - return FlatRewards(torch.zeros((0, len(self.objectives)))), is_valid_t + return ObjectProperties(torch.zeros((0, len(self.objectives)))), is_valid_t else: flat_r: List[Tensor] = [] for obj in self.objectives: @@ -188,7 +187,7 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: flat_rewards = torch.stack(flat_r, dim=1) assert flat_rewards.shape[0] == is_valid_t.sum() - return FlatRewards(flat_rewards), is_valid_t + return ObjectProperties(flat_rewards), is_valid_t class QM9MOOTrainer(QM9GapTrainer): @@ -204,9 +203,9 @@ def set_default_hps(self, cfg: Config): def setup_algo(self): algo = self.cfg.algo.method if algo == "MOREINFORCE": - self.algo = MultiObjectiveReinforce(self.env, self.ctx, self.rng, self.cfg) + self.algo = MultiObjectiveReinforce(self.env, self.ctx, self.cfg) elif algo == "MOQL": - self.algo = EnvelopeQLearning(self.env, self.ctx, self.task, self.rng, self.cfg) + self.algo = EnvelopeQLearning(self.env, self.ctx, self.task, self.cfg) else: super().setup_algo() @@ -214,7 
+213,6 @@ def setup_task(self): self.task = QM9GapMOOTask( dataset=self.training_data, cfg=self.cfg, - rng=self.rng, wrap_model=self._wrap_for_mp, ) diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index d9da0386..2adca4f4 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -1,7 +1,6 @@ import socket -from typing import Callable, Dict, List, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple -import numpy as np import torch import torch.nn as nn import torch_geometric.data as gd @@ -11,7 +10,7 @@ from torch.utils.data import Dataset from torch_geometric.data import Data -from gflownet import FlatRewards, GFNTask, RewardScalar +from gflownet import GFNTask, LogScalar, ObjectProperties from gflownet.config import Config, init_empty from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext, Graph from gflownet.models import bengio2021flow @@ -33,24 +32,14 @@ class SEHTask(GFNTask): def __init__( self, - dataset: Dataset, cfg: Config, - rng: np.random.Generator = None, - wrap_model: Callable[[nn.Module], nn.Module] = None, - ): - self._wrap_model = wrap_model - self.rng = rng + wrap_model: Optional[Callable[[nn.Module], nn.Module]] = None, + ) -> None: + self._wrap_model = wrap_model if wrap_model is not None else (lambda x: x) self.models = self._load_task_models() - self.dataset = dataset - self.temperature_conditional = TemperatureConditional(cfg, rng) + self.temperature_conditional = TemperatureConditional(cfg) self.num_cond_dim = self.temperature_conditional.encoding_size() - def flat_reward_transform(self, y: Union[float, Tensor]) -> FlatRewards: - return FlatRewards(torch.as_tensor(y) / 8) - - def inverse_flat_reward_transform(self, rp): - return rp * 8 - def _load_task_models(self): model = bengio2021flow.load_original_model() model.to(get_worker_device()) @@ -60,25 +49,25 @@ def _load_task_models(self): def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: return self.temperature_conditional.sample(n) - def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: - return RewardScalar(self.temperature_conditional.transform(cond_info, to_logreward(flat_reward))) + def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: ObjectProperties) -> LogScalar: + return LogScalar(self.temperature_conditional.transform(cond_info, to_logreward(flat_reward))) def compute_reward_from_graph(self, graphs: List[Data]) -> Tensor: batch = gd.Batch.from_data_list([i for i in graphs if i is not None]) batch.to(self.models["seh"].device if hasattr(self.models["seh"], "device") else get_worker_device()) - preds = self.models["seh"](batch).reshape((-1,)).data.cpu() + preds = self.models["seh"](batch).reshape((-1,)).data.cpu() / 8 preds[preds.isnan()] = 0 - return self.flat_reward_transform(preds).clip(1e-4, 100).reshape((-1,)) + return preds.clip(1e-4, 100).reshape((-1,)) - def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: + def compute_obj_properties(self, mols: List[RDMol]) -> Tuple[ObjectProperties, Tensor]: graphs = [bengio2021flow.mol2graph(i) for i in mols] is_valid = torch.tensor([i is not None for i in graphs]).bool() if not is_valid.any(): - return FlatRewards(torch.zeros((0, 1))), is_valid + return ObjectProperties(torch.zeros((0, 1))), is_valid preds = self.compute_reward_from_graph(graphs).reshape((-1, 1)) assert len(preds) == is_valid.sum() - return FlatRewards(preds), is_valid + return 
ObjectProperties(preds), is_valid SOME_MOLS = [ @@ -117,14 +106,14 @@ class LittleSEHDataset(Dataset): def __init__(self, smis) -> None: super().__init__() - self.props: List[Tensor] = [] + self.props: ObjectProperties self.mols: List[Graph] = [] self.smis = smis - def setup(self, task, ctx): + def setup(self, task: SEHTask, ctx: FragMolBuildingEnvContext) -> None: rdmols = [Chem.MolFromSmiles(i) for i in SOME_MOLS] - self.mols = [ctx.mol_to_graph(i) for i in rdmols] - self.props = task.compute_flat_rewards(rdmols)[0] + self.mols = [ctx.obj_to_graph(i) for i in rdmols] + self.props = task.compute_obj_properties(rdmols)[0] def __len__(self): return len(self.mols) @@ -173,9 +162,7 @@ def set_default_hps(self, cfg: Config): def setup_task(self): self.task = SEHTask( - dataset=self.training_data, cfg=self.cfg, - rng=self.rng, wrap_model=self._wrap_for_mp, ) diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index d3880b5c..c4f0b67f 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -1,5 +1,5 @@ import pathlib -from typing import Any, Callable, Dict, List, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import numpy as np import torch @@ -8,9 +8,9 @@ from rdkit.Chem import QED, Descriptors from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader -from gflownet import FlatRewards, RewardScalar +from gflownet import LogScalar, ObjectProperties from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL from gflownet.algo.multiobjective_reinforce import MultiObjectiveReinforce from gflownet.config import Config, init_empty @@ -63,19 +63,16 @@ class SEHMOOTask(SEHTask): def __init__( self, - dataset: Dataset, cfg: Config, - rng: np.random.Generator = None, - wrap_model: Callable[[nn.Module], nn.Module] = None, + wrap_model: Optional[Callable[[nn.Module], nn.Module]] = None, ): - super().__init__(dataset, cfg, rng, wrap_model) + super().__init__(cfg, wrap_model) self.cfg = cfg mcfg = self.cfg.task.seh_moo self.objectives = cfg.task.seh_moo.objectives cfg.cond.moo.num_objectives = len(self.objectives) # This value is used by the focus_cond and pref_cond - self.dataset = dataset if self.cfg.cond.focus_region.focus_type is not None: - self.focus_cond = FocusRegionConditional(self.cfg, mcfg.n_valid, rng) + self.focus_cond = FocusRegionConditional(self.cfg, mcfg.n_valid) else: self.focus_cond = None self.pref_cond = MultiObjectiveWeightedPreferences(self.cfg) @@ -89,9 +86,6 @@ def __init__( ) assert set(self.objectives) <= {"seh", "qed", "sa", "mw"} and len(self.objectives) == len(set(self.objectives)) - def inverse_flat_reward_transform(self, rp): - return rp - def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: cond_info = super().sample_conditional_information(n, train_it) pref_ci = self.pref_cond.sample(n) @@ -144,7 +138,7 @@ def encode_conditional_information(self, steer_info: Tensor) -> Dict[str, Tensor } def relabel_condinfo_and_logrewards( - self, cond_info: Dict[str, Tensor], log_rewards: Tensor, flat_rewards: FlatRewards, hindsight_idxs: Tensor + self, cond_info: Dict[str, Tensor], log_rewards: Tensor, obj_props: ObjectProperties, hindsight_idxs: Tensor ): # TODO: we seem to be relabeling tensors in place, could that cause a problem? 
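For the multi-objective tasks, `cond_info_to_logreward` (seen above for `qm9_moo.py` and below for `seh_frag_moo.py`) composes preference scalarization, an optional focus-region term, and temperature scaling. A simplified standalone sketch of that composition, with the focus term omitted and purely illustrative values (the -75.0 stands in for `cfg.algo.illegal_action_logreward`):

import torch

obj_props = torch.rand(4, 3)                    # e.g. per-object (seh, qed, sa) properties
prefs = torch.softmax(torch.rand(4, 3), dim=1)  # sampled preference weights
beta = torch.full((4,), 32.0)                   # sampled inverse temperatures
illegal_logreward = -75.0                       # placeholder value

scalarized = (obj_props * prefs).sum(1)         # MultiObjectiveWeightedPreferences.transform
log_r = scalarized.clamp(min=1e-30).log()       # to_logreward
tempered = log_r * beta                         # TemperatureConditional.transform
clamped = tempered.clamp(min=illegal_logreward)
print(clamped.shape)                            # torch.Size([4])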
if self.focus_cond is None: @@ -153,13 +147,13 @@ def relabel_condinfo_and_logrewards( return cond_info, log_rewards # only keep hindsight_idxs that actually correspond to a violated constraint _, in_focus_mask = metrics.compute_focus_coef( - flat_rewards, cond_info["focus_dir"], self.focus_cond.cfg.focus_cosim + obj_props, cond_info["focus_dir"], self.focus_cond.cfg.focus_cosim ) out_focus_mask = torch.logical_not(in_focus_mask) hindsight_idxs = hindsight_idxs[out_focus_mask[hindsight_idxs]] # relabels the focus_dirs and log_rewards - cond_info["focus_dir"][hindsight_idxs] = nn.functional.normalize(flat_rewards[hindsight_idxs], dim=1) + cond_info["focus_dir"][hindsight_idxs] = nn.functional.normalize(obj_props[hindsight_idxs], dim=1) preferences_enc = self.pref_cond.encode(cond_info["preferences"]) focus_enc = self.focus_cond.encode(cond_info["focus_dir"]) @@ -167,13 +161,15 @@ def relabel_condinfo_and_logrewards( [cond_info["encoding"][:, : self.num_thermometer_dim], preferences_enc, focus_enc], 1 ) - log_rewards = self.cond_info_to_logreward(cond_info, flat_rewards) + log_rewards = self.cond_info_to_logreward(cond_info, obj_props) return cond_info, log_rewards - def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: + def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], obj_props: ObjectProperties) -> LogScalar: """ - Compute the logreward from the flat_reward and the conditional information + Compute the logreward from the object properties, which we interpret as each objective, and the conditional + information """ + flat_reward = obj_props if isinstance(flat_reward, list): if isinstance(flat_reward[0], Tensor): flat_reward = torch.stack(flat_reward) @@ -183,22 +179,22 @@ def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: Flat scalarized_rewards = self.pref_cond.transform(cond_info, flat_reward) scalarized_logrewards = to_logreward(scalarized_rewards) focused_logreward = ( - self.focus_cond.transform(cond_info, flat_reward, scalarized_logrewards) + self.focus_cond.transform(cond_info, (flat_reward, scalarized_logrewards)) if self.focus_cond is not None else scalarized_logrewards ) tempered_logreward = self.temperature_conditional.transform(cond_info, focused_logreward) clamped_logreward = tempered_logreward.clamp(min=self.cfg.algo.illegal_action_logreward) - return RewardScalar(clamped_logreward) + return LogScalar(clamped_logreward) - def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: + def compute_obj_properties(self, mols: List[RDMol]) -> Tuple[ObjectProperties, Tensor]: graphs = [bengio2021flow.mol2graph(i) for i in mols] assert len(graphs) == len(mols) is_valid = [i is not None for i in graphs] is_valid_t = torch.tensor(is_valid, dtype=torch.bool) if not any(is_valid): - return FlatRewards(torch.zeros((0, len(self.objectives)))), is_valid_t + return ObjectProperties(torch.zeros((0, len(self.objectives)))), is_valid_t else: flat_r: List[Tensor] = [] for obj in self.objectives: @@ -209,7 +205,7 @@ def compute_flat_rewards(self, mols: List[RDMol]) -> Tuple[FlatRewards, Tensor]: flat_rewards = torch.stack(flat_r, dim=1) assert flat_rewards.shape[0] == len(mols) - return FlatRewards(flat_rewards), is_valid_t + return ObjectProperties(flat_rewards), is_valid_t class SEHMOOFragTrainer(SEHFragTrainer): @@ -227,18 +223,16 @@ def set_default_hps(self, cfg: Config): def setup_algo(self): algo = self.cfg.algo.method if algo == "MOREINFORCE": - self.algo = 
MultiObjectiveReinforce(self.env, self.ctx, self.rng, self.cfg) + self.algo = MultiObjectiveReinforce(self.env, self.ctx, self.cfg) elif algo == "MOQL": - self.algo = EnvelopeQLearning(self.env, self.ctx, self.task, self.rng, self.cfg) + self.algo = EnvelopeQLearning(self.env, self.ctx, self.task, self.cfg) else: super().setup_algo() def setup_task(self): self.cfg.cond.moo.num_objectives = len(self.cfg.task.seh_moo.objectives) self.task = SEHMOOTask( - dataset=self.training_data, cfg=self.cfg, - rng=self.rng, wrap_model=self._wrap_for_mp, ) diff --git a/src/gflownet/tasks/toy_seq.py b/src/gflownet/tasks/toy_seq.py index f2c75f60..a215948b 100644 --- a/src/gflownet/tasks/toy_seq.py +++ b/src/gflownet/tasks/toy_seq.py @@ -1,11 +1,10 @@ import socket from typing import Dict, List, Tuple -import numpy as np import torch from torch import Tensor -from gflownet import FlatRewards, GFNTask, RewardScalar +from gflownet import GFNTask, LogScalar, ObjectProperties from gflownet.config import Config, init_empty from gflownet.envs.seq_building_env import AutoregressiveSeqBuildingContext, SeqBuildingEnv from gflownet.models.seq_transformer import SeqTransformerGFN @@ -22,22 +21,21 @@ def __init__( self, seqs: List[str], cfg: Config, - rng: np.random.Generator, - ): + ) -> None: self.seqs = seqs - self.temperature_conditional = TemperatureConditional(cfg, rng) + self.temperature_conditional = TemperatureConditional(cfg) self.num_cond_dim = self.temperature_conditional.encoding_size() self.norm = cfg.algo.max_len / min(map(len, seqs)) def sample_conditional_information(self, n: int, train_it: int) -> Dict[str, Tensor]: return self.temperature_conditional.sample(n) - def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: - return RewardScalar(self.temperature_conditional.transform(cond_info, to_logreward(flat_reward))) + def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], obj_props: ObjectProperties) -> LogScalar: + return LogScalar(self.temperature_conditional.transform(cond_info, to_logreward(obj_props))) - def compute_flat_rewards(self, objs: List[str]) -> Tuple[FlatRewards, Tensor]: + def compute_obj_properties(self, objs: List[str]) -> Tuple[ObjectProperties, Tensor]: rs = torch.tensor([sum([s.count(p) for p in self.seqs]) for s in objs]).float() / self.norm - return FlatRewards(rs[:, None]), torch.ones(len(objs), dtype=torch.bool) + return ObjectProperties(rs[:, None]), torch.ones(len(objs), dtype=torch.bool) class ToySeqTrainer(StandardOnlineTrainer): @@ -47,6 +45,7 @@ def set_default_hps(self, cfg: Config): cfg.hostname = socket.gethostname() cfg.pickle_mp_messages = False cfg.num_workers = 8 + cfg.num_validation_gen_steps = 1 cfg.opt.learning_rate = 1e-4 cfg.opt.weight_decay = 1e-8 cfg.opt.momentum = 0.9 @@ -81,7 +80,6 @@ def setup_task(self): self.task = ToySeqTask( ["aa", "bb", "cc"], cfg=self.cfg, - rng=self.rng, ) def setup_env_context(self): diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 894082f0..386c0494 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -1,4 +1,5 @@ import gc +import logging import os import pathlib import shutil @@ -21,7 +22,7 @@ from gflownet.data.replay_buffer import ReplayBuffer from gflownet.envs.graph_building_env import GraphActionCategorical, GraphBuildingEnv, GraphBuildingEnvContext from gflownet.envs.seq_building_env import SeqBatch -from gflownet.utils.misc import create_logger, set_main_process_device +from gflownet.utils.misc import create_logger, 
set_main_process_device, set_worker_rng_seed from gflownet.utils.multiprocessing_proxy import mp_object_wrapper from gflownet.utils.sqlite_log import SQLiteLogHook @@ -114,7 +115,7 @@ def setup(self): os.makedirs(self.cfg.log_dir) RDLogger.DisableLog("rdApp.*") - self.rng = np.random.default_rng(142857) + set_worker_rng_seed(self.cfg.seed) self.env = GraphBuildingEnv() self.setup_data() self.setup_task() @@ -318,7 +319,7 @@ def run(self, logger=None): v = v.item() final_info[k].append(v) if it % self.print_every == 0: - logger.info(f"Generating mols {it - num_training_steps}/{num_final_gen_steps}") + logger.info(f"Generating objs {it - num_training_steps}/{num_final_gen_steps}") final_info = {k: np.mean(v) for k, v in final_info.items()} logger.info("Final generation steps completed - " + " ".join(f"{k}:{v:.2f}" for k, v in final_info.items())) @@ -331,6 +332,10 @@ def run(self, logger=None): del final_dl def terminate(self): + logger = logging.getLogger("logger") + for handler in logger.handlers: + handler.close() + for hook in self.sampling_hooks: if hasattr(hook, "terminate") and hook.terminate not in self.to_terminate: hook.terminate() diff --git a/src/gflownet/utils/conditioning.py b/src/gflownet/utils/conditioning.py index 0630be55..952dae77 100644 --- a/src/gflownet/utils/conditioning.py +++ b/src/gflownet/utils/conditioning.py @@ -1,6 +1,6 @@ import abc from copy import deepcopy -from typing import Dict +from typing import Dict, Generic, Optional, TypeVar import numpy as np import torch @@ -9,19 +9,23 @@ from torch.distributions.dirichlet import Dirichlet from torch_geometric import data as gd +from gflownet import LinScalar, LogScalar, ObjectProperties from gflownet.config import Config from gflownet.utils import metrics from gflownet.utils.focus_model import TabularFocusModel -from gflownet.utils.misc import get_worker_device +from gflownet.utils.misc import get_worker_device, get_worker_rng from gflownet.utils.transforms import thermometer +Tin = TypeVar("Tin") +Tout = TypeVar("Tout") -class Conditional(abc.ABC): + +class Conditional(abc.ABC, Generic[Tin, Tout]): def sample(self, n): raise NotImplementedError() @abc.abstractmethod - def transform(self, cond_info: Dict[str, Tensor], properties: Tensor) -> Tensor: + def transform(self, cond_info: Dict[str, Tensor], data: Tin) -> Tout: raise NotImplementedError() def encoding_size(self): @@ -31,11 +35,10 @@ def encode(self, conditional: Tensor) -> Tensor: raise NotImplementedError() -class TemperatureConditional(Conditional): - def __init__(self, cfg: Config, rng: np.random.Generator): +class TemperatureConditional(Conditional[LogScalar, LogScalar]): + def __init__(self, cfg: Config): self.cfg = cfg tmp_cfg = self.cfg.cond.temperature - self.rng = rng self.upper_bound = 1024 if tmp_cfg.sample_dist == "gamma": loc, scale = tmp_cfg.dist_params @@ -53,6 +56,7 @@ def encoding_size(self): def sample(self, n): cfg = self.cfg.cond.temperature beta = None + rng = get_worker_rng() if cfg.sample_dist == "constant": if isinstance(cfg.dist_params[0], (float, int, np.int64, np.int32)): beta = np.array(cfg.dist_params[0]).repeat(n).astype(np.float32) @@ -62,26 +66,26 @@ def sample(self, n): else: if cfg.sample_dist == "gamma": loc, scale = cfg.dist_params - beta = self.rng.gamma(loc, scale, n).astype(np.float32) + beta = rng.gamma(loc, scale, n).astype(np.float32) elif cfg.sample_dist == "uniform": a, b = float(cfg.dist_params[0]), float(cfg.dist_params[1]) - beta = self.rng.uniform(a, b, n).astype(np.float32) + beta = rng.uniform(a, b, 
n).astype(np.float32) elif cfg.sample_dist == "loguniform": low, high = np.log(cfg.dist_params) - beta = np.exp(self.rng.uniform(low, high, n).astype(np.float32)) + beta = np.exp(rng.uniform(low, high, n).astype(np.float32)) elif cfg.sample_dist == "beta": a, b = float(cfg.dist_params[0]), float(cfg.dist_params[1]) - beta = self.rng.beta(a, b, n).astype(np.float32) + beta = rng.beta(a, b, n).astype(np.float32) beta_enc = thermometer(torch.tensor(beta), cfg.num_thermometer_dim, 0, self.upper_bound) assert len(beta.shape) == 1, f"beta should be a 1D array, got {beta.shape}" return {"beta": torch.tensor(beta), "encoding": beta_enc} - def transform(self, cond_info: Dict[str, Tensor], logreward: Tensor) -> Tensor: + def transform(self, cond_info: Dict[str, Tensor], logreward: LogScalar) -> LogScalar: assert len(logreward.shape) == len( cond_info["beta"].shape ), f"dangerous shape mismatch: {logreward.shape} vs {cond_info['beta'].shape}" - return logreward * cond_info["beta"] + return LogScalar(logreward * cond_info["beta"]) def encode(self, conditional: Tensor) -> Tensor: cfg = self.cfg.cond.temperature @@ -90,7 +94,7 @@ def encode(self, conditional: Tensor) -> Tensor: return thermometer(torch.tensor(conditional), cfg.num_thermometer_dim, 0, self.upper_bound) -class MultiObjectiveWeightedPreferences(Conditional): +class MultiObjectiveWeightedPreferences(Conditional[ObjectProperties, LinScalar]): def __init__(self, cfg: Config): self.cfg = cfg.cond.weighted_prefs self.num_objectives = cfg.cond.moo.num_objectives @@ -115,10 +119,10 @@ def sample(self, n): preferences = torch.as_tensor(preferences).float() return {"preferences": preferences, "encoding": self.encode(preferences)} - def transform(self, cond_info: Dict[str, Tensor], flat_reward: Tensor) -> Tensor: + def transform(self, cond_info: Dict[str, Tensor], flat_reward: ObjectProperties) -> LinScalar: scalar_reward = (flat_reward * cond_info["preferences"]).sum(1) assert len(scalar_reward.shape) == 1, f"scalar_reward should be a 1D array, got {scalar_reward.shape}" - return scalar_reward + return LinScalar(scalar_reward) def encoding_size(self): return max(1, self.num_thermometer_dim * self.num_objectives) @@ -130,13 +134,12 @@ def encode(self, conditional: Tensor) -> Tensor: return conditional.unsqueeze(1) -class FocusRegionConditional(Conditional): - def __init__(self, cfg: Config, n_valid: int, rng: np.random.Generator): +class FocusRegionConditional(Conditional[tuple[ObjectProperties, LogScalar], LogScalar]): + def __init__(self, cfg: Config, n_valid: int): self.cfg = cfg.cond.focus_region self.n_valid = n_valid self.n_objectives = cfg.cond.moo.num_objectives self.ocfg = cfg - self.rng = rng self.num_thermometer_dim = cfg.cond.moo.num_thermometer_dim if self.cfg.use_steer_thermomether else 0 focus_type = self.cfg.focus_type @@ -189,15 +192,16 @@ def setup_focus_regions(self): ) self.valid_focus_dirs = valid_focus_dirs - def sample(self, n: int, train_it: int = None): + def sample(self, n: int, train_it: Optional[int] = None): train_it = train_it or 0 + rng = get_worker_rng() if self.fixed_focus_dirs is not None: focus_dir = torch.tensor( - np.array(self.fixed_focus_dirs)[self.rng.choice(len(self.fixed_focus_dirs), n)].astype(np.float32) + np.array(self.fixed_focus_dirs)[rng.choice(len(self.fixed_focus_dirs), n)].astype(np.float32) ) elif self.cfg.focus_type == "dirichlet": m = Dirichlet(torch.FloatTensor([1.0] * self.n_objectives)) - focus_dir = m.sample([n]) + focus_dir = m.sample(torch.Size((n,))) elif self.cfg.focus_type == 
"hyperspherical": focus_dir = torch.tensor( metrics.sample_positiveQuadrant_ndim_sphere(n, self.n_objectives, normalisation="l2") @@ -224,11 +228,12 @@ def encode(self, conditional: Tensor) -> Tensor: else conditional ) - def transform(self, cond_info: Dict[str, Tensor], flat_rewards: Tensor, scalar_logreward: Tensor = None) -> Tensor: + def transform(self, cond_info: Dict[str, Tensor], data: tuple[ObjectProperties, LogScalar]) -> LogScalar: + flat_rewards, scalar_logreward = data focus_coef, in_focus_mask = metrics.compute_focus_coef( flat_rewards, cond_info["focus_dir"], self.cfg.focus_cosim, self.cfg.focus_limit_coef ) - scalar_logreward = scalar_logreward.clone() # Avoid modifying the original tensor + scalar_logreward = LogScalar(scalar_logreward.clone()) # Avoid modifying the original tensor scalar_logreward[in_focus_mask] += torch.log(focus_coef[in_focus_mask]) scalar_logreward[~in_focus_mask] = self.ocfg.algo.illegal_action_logreward diff --git a/src/gflownet/utils/sqlite_log.py b/src/gflownet/utils/sqlite_log.py index 1ac183db..ae544ec5 100644 --- a/src/gflownet/utils/sqlite_log.py +++ b/src/gflownet/utils/sqlite_log.py @@ -12,29 +12,29 @@ def __init__(self, log_dir, ctx) -> None: self.ctx = ctx self.data_labels = None - def __call__(self, trajs, rewards, flat_rewards, cond_info): + def __call__(self, trajs, rewards, obj_props, cond_info): if self.log is None: worker_info = torch.utils.data.get_worker_info() self._wid = worker_info.id if worker_info is not None else 0 os.makedirs(self.log_dir, exist_ok=True) - self.log_path = f"{self.log_dir}/generated_mols_{self._wid}.db" + self.log_path = f"{self.log_dir}/generated_objs_{self._wid}.db" self.log = SQLiteLog() self.log.connect(self.log_path) if hasattr(self.ctx, "object_to_log_repr"): - mols = [self.ctx.object_to_log_repr(t["result"]) if t["is_valid"] else "" for t in trajs] + objs = [self.ctx.object_to_log_repr(t["result"]) if t["is_valid"] else "" for t in trajs] else: - mols = [""] * len(trajs) + objs = [""] * len(trajs) - flat_rewards = flat_rewards.reshape((len(flat_rewards), -1)).data.numpy().tolist() + obj_props = obj_props.reshape((len(obj_props), -1)).data.numpy().tolist() rewards = rewards.data.numpy().tolist() - preferences = cond_info.get("preferences", torch.zeros((len(mols), 0))).data.numpy().tolist() - focus_dir = cond_info.get("focus_dir", torch.zeros((len(mols), 0))).data.numpy().tolist() + preferences = cond_info.get("preferences", torch.zeros((len(objs), 0))).data.numpy().tolist() + focus_dir = cond_info.get("focus_dir", torch.zeros((len(objs), 0))).data.numpy().tolist() logged_keys = [k for k in sorted(cond_info.keys()) if k not in ["encoding", "preferences", "focus_dir"]] data = [ - [mols[i], rewards[i]] - + flat_rewards[i] + [objs[i], rewards[i]] + + obj_props[i] + preferences[i] + focus_dir[i] + [cond_info[k][i].item() for k in logged_keys] @@ -43,7 +43,7 @@ def __call__(self, trajs, rewards, flat_rewards, cond_info): if self.data_labels is None: self.data_labels = ( ["smi", "r"] - + [f"fr_{i}" for i in range(len(flat_rewards[0]))] + + [f"fr_{i}" for i in range(len(obj_props[0]))] + [f"pref_{i}" for i in range(len(preferences[0]))] + [f"focus_{i}" for i in range(len(focus_dir[0]))] + [f"ci_{k}" for k in logged_keys] @@ -93,3 +93,19 @@ def insert_many(self, rows, column_names): cur.executemany(f'insert into results values ({",".join("?"*len(rows[0]))})', rows) # nosec cur.close() self.db.commit() + + def __del__(self): + if self.db is not None: + self.db.close() + + +def read_all_results(path): + # E402: 
module level import not at top of file, but pandas is an optional dependency + import pandas as pd # noqa: E402 + + num_workers = len([f for f in os.listdir(path) if f.startswith("generated_objs")]) + dfs = [ + pd.read_sql_query("SELECT * FROM results", sqlite3.connect(f"file:{path}/generated_objs_{i}.db?mode=ro")) + for i in range(num_workers) + ] + return pd.concat(dfs).sort_index().reset_index(drop=True) diff --git a/src/gflownet/utils/transforms.py b/src/gflownet/utils/transforms.py index f5428ddc..d1c70793 100644 --- a/src/gflownet/utils/transforms.py +++ b/src/gflownet/utils/transforms.py @@ -1,9 +1,11 @@ import torch from torch import Tensor +from gflownet import LogScalar -def to_logreward(reward: Tensor) -> Tensor: - return reward.squeeze().clamp(min=1e-30).log() + +def to_logreward(reward: Tensor) -> LogScalar: + return LogScalar(reward.squeeze().clamp(min=1e-30).log()) def thermometer(v: Tensor, n_bins: int = 50, vmin: float = 0, vmax: float = 1) -> Tensor: diff --git a/tests/test_envs.py b/tests/test_envs.py index b371e03f..8838469b 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -3,20 +3,19 @@ import networkx as nx import pytest -from omegaconf import OmegaConf from gflownet.algo.trajectory_balance import TrajectoryBalance from gflownet.config import Config from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext -from gflownet.envs.graph_building_env import ActionIndex, GraphBuildingEnv +from gflownet.envs.graph_building_env import ActionIndex, GraphBuildingEnv, GraphBuildingEnvContext from gflownet.envs.mol_building_env import MolBuildingEnvContext from gflownet.models import bengio2021flow -def build_two_node_states(ctx): +def build_two_node_states(ctx: GraphBuildingEnvContext): # TODO: This is actually fairly generic code that will probably be reused by other tests in the future. # Having a proper class to handle graph-indexed hash maps would probably be good. - graph_cache = {} + graph_cache: dict[str, nx.Graph] = {} graph_by_idx = {} _graph_cache_buckets = {} @@ -74,11 +73,11 @@ def expand(s, idx): return [graph_by_idx[i] for i in list(nx.topological_sort(mdp_graph))] -def get_frag_env_ctx(): +def get_frag_env_ctx() -> FragMolBuildingEnvContext: return FragMolBuildingEnvContext(max_frags=2, fragments=bengio2021flow.FRAGMENTS[:20]) -def get_atom_env_ctx(): +def get_atom_env_ctx() -> MolBuildingEnvContext: return MolBuildingEnvContext(atoms=["C", "N"], expl_H_range=[0], charges=[0], max_nodes=2) @@ -106,7 +105,7 @@ def two_node_states_atoms(request): return data -def _test_backwards_action_mask_equivalence(two_node_states, ctx): +def _test_backwards_action_mask_equivalence(two_node_states: list[nx.Graph], ctx: GraphBuildingEnvContext) -> None: """This tests that FragMolBuildingEnvContext implements backwards action masks correctly. It treats GraphBuildingEnv.count_backward_transitions as the ground truth and raises an error if there is a different number of actions leading to the parents of any state. @@ -124,7 +123,7 @@ def _test_backwards_action_mask_equivalence(two_node_states, ctx): raise ValueError() -def _test_backwards_action_mask_equivalence_ipa(two_node_states, ctx): +def _test_backwards_action_mask_equivalence_ipa(two_node_states: list[nx.Graph], ctx: GraphBuildingEnvContext) -> None: """This tests that FragMolBuildingEnvContext implements backwards masks correctly. It treats GraphBuildingEnv.count_backward_transitions as the ground truth and raises an error if there is a different number of actions leading to the parents of any state. 
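The `read_all_results` helper added to `sqlite_log.py` above pulls everything a run has generated into a single DataFrame. A usage sketch (the directory below is a placeholder for wherever the `SQLiteLogHook` wrote its `generated_objs_*.db` files; pandas must be installed):

from gflownet.utils.sqlite_log import read_all_results

df = read_all_results("./logs/my_run/train")  # placeholder path
print(len(df), list(df.columns))              # columns look like: smi, r, fr_0, ..., pref_*, focus_*, ci_beta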
@@ -132,16 +131,16 @@ def _test_backwards_action_mask_equivalence_ipa(two_node_states, ctx):
     This test also accounts for idempotent actions.
     """
     env = GraphBuildingEnv()
-    cfg = OmegaConf.structured(Config)
+    cfg = Config()
     cfg.algo.max_nodes = 2
-    algo = TrajectoryBalance(env, ctx, None, cfg)
+    algo = TrajectoryBalance(env, ctx, cfg)
     for i in range(1, len(two_node_states)):
         g = two_node_states[i]
         n = env.count_backward_transitions(g, check_idempotent=True)
         gd = ctx.graph_to_Data(g)
         # To check that we're computing masks correctly, we need to check that there is the same
         # number of idempotent action classes, i.e. groups of actions that lead to the same parent.
-        equivalence_classes = []
+        equivalence_classes: list[list[tuple[int, int, int]]] = []
         for u, k in enumerate(ctx.bck_action_type_order):
             m = getattr(gd, k.mask_name)
             for aidx in m.nonzero():
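Finally, as the updated test above shows, algorithm constructors now take `(env, ctx, cfg)` with no RNG argument, and a plain `Config()` replaces `OmegaConf.structured(Config)`. A construction sketch mirroring the test (same environment and context arguments as in `get_atom_env_ctx`):

from gflownet.algo.trajectory_balance import TrajectoryBalance
from gflownet.config import Config
from gflownet.envs.graph_building_env import GraphBuildingEnv
from gflownet.envs.mol_building_env import MolBuildingEnvContext

cfg = Config()
cfg.algo.max_nodes = 2
env = GraphBuildingEnv()
ctx = MolBuildingEnvContext(atoms=["C", "N"], expl_H_range=[0], charges=[0], max_nodes=2)
algo = TrajectoryBalance(env, ctx, cfg)  # randomness is now seeded globally via set_worker_rng_seed(cfg.seed)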