From 71c0d0d377ddf9479223dfbcda6added515e28ac Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Wed, 13 Mar 2024 09:22:42 -0600 Subject: [PATCH] some renaming and None-proofing of cond_info --- docs/contributing.md | 34 ++++++++++++++++++++++++ docs/getting_started.md | 2 ++ src/gflownet/algo/graph_sampling.py | 17 ++++++------ src/gflownet/algo/trajectory_balance.py | 30 ++++++++++++++------- src/gflownet/models/graph_transformer.py | 16 +++++++---- src/gflownet/models/mxmnet.py | 1 - src/gflownet/tasks/seh_frag.py | 8 +++--- src/gflownet/tasks/seh_frag_moo.py | 10 +++---- 8 files changed, 83 insertions(+), 35 deletions(-) diff --git a/docs/contributing.md b/docs/contributing.md index e69de29b..a85e945c 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -0,0 +1,34 @@ +# Contributing + +Contributions to the repository are welcome, and we encourage you to open issues and pull requests. In general, it is recommended to fork this repository and open a pull request from your fork to the `trunk` branch. PRs are encouraged to be short and focused, and to include tests and documentation where appropriate. + +## Installation + +To install the developer dependencies run: +``` +pip install -e '.[dev]' --find-links https://data.pyg.org/whl/torch-2.1.2+cu121.html +``` + +## Dependencies + +Dependencies are defined in `pyproject.toml`, and frozen versions that are known to work are provided in `requirements/`. + +To regenerate the frozen versions, run `./generate_requirements.sh `. See comments within. + +## Linting and testing + +We use `tox` to run tests and linting, and `pre-commit` to run checks before committing. +To ensure that these checks pass, simply run `tox -e style` and `tox run` to run linters and tests, respectively. + +`tox` itself runs many linters, but the most important ones are `black`, `ruff`, `isort`, and `mypy`. 
The full list +of linting tools is found in `.pre-commit-config.yaml`, while `tox.ini` defines the environments under which these +linters (as well as tests) are run. + +## GitHub Actions + +We use GitHub Actions to run tests and linting on every push and pull request. The configuration for these actions is found in `.github/workflows/`. + +## Style Guide + +On top of `black`-as-a-style-guide, we generally adhere to the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html). +Our docstrings follow the [numpydoc](https://numpydoc.readthedocs.io/en/latest/format.html) format, and we use type hints throughout the codebase. diff --git a/docs/getting_started.md b/docs/getting_started.md index e69de29b..5ff2d7e5 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -0,0 +1,2 @@ +# Getting Started + diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 7bb4674c..ebc7e48a 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -8,7 +8,7 @@ from gflownet.envs.graph_building_env import Graph, GraphAction, GraphActionCategorical, GraphActionType from gflownet.models.graph_transformer import GraphTransformerGFN -from gflownet.utils.misc import get_worker_rng +from gflownet.utils.misc import get_worker_rng, get_worker_device def relabel(g: Graph, ga: GraphAction): @@ -63,7 +63,7 @@ def __init__( self.pad_with_terminal_state = pad_with_terminal_state def sample_from_model( - self, model: nn.Module, n: int, cond_info: Tensor, dev: torch.device, random_action_prob: float = 0.0 + self, model: nn.Module, n: int, cond_info: Optional[Tensor], random_action_prob: float = 0.0 ): """Samples a model in a minibatch @@ -75,8 +75,6 @@ def sample_from_model( Number of graphs to sample cond_info: Tensor Conditional information of each trajectory, shape (n, n_info) - dev: torch.device - Device on which data is manipulated Returns ------- @@ -87,6 +85,7 @@ def sample_from_model( - 
bck_logprob: sum logprobs P_B - is_valid: is the generated graph valid according to the env & ctx """ + dev = get_worker_device() # This will be returned data = [{"traj": [], "reward_pred": None, "is_valid": True, "is_sink": []} for i in range(n)] # Let's also keep track of trajectory statistics according to the model @@ -114,7 +113,8 @@ def not_done(lst): # Forward pass to get GraphActionCategorical # Note about `*_`, the model may be outputting its own bck_cat, but we ignore it if it does. # TODO: compute bck_cat.log_prob(bck_a) when relevant - fwd_cat, *_, log_reward_preds = model(self.ctx.collate(torch_graphs).to(dev), cond_info[not_done_mask]) + ci = cond_info[not_done_mask] if cond_info is not None else None + fwd_cat, *_, log_reward_preds = model(self.ctx.collate(torch_graphs).to(dev), ci) if random_action_prob > 0: masks = [1] * len(fwd_cat.logits) if fwd_cat.masks is None else fwd_cat.masks # Device which graphs in the minibatch will get their action randomized @@ -208,8 +208,7 @@ def sample_backward_from_graphs( self, graphs: List[Graph], model: Optional[nn.Module], - cond_info: Tensor, - dev: torch.device, + cond_info: Optional[Tensor], random_action_prob: float = 0.0, ): """Sample a model's P_B starting from a list of graphs, or if the model is None, use a uniform distribution @@ -229,6 +228,7 @@ def sample_backward_from_graphs( Probability of taking a random action (only used if model parameterizes P_B) """ + dev = get_worker_device() n = len(graphs) done = [False] * n data = [ @@ -254,7 +254,8 @@ def not_done(lst): torch_graphs = [self.ctx.graph_to_Data(graphs[i]) for i in not_done(range(n))] not_done_mask = torch.tensor(done, device=dev).logical_not() if model is not None: - _, bck_cat, *_ = model(self.ctx.collate(torch_graphs).to(dev), cond_info[not_done_mask]) + ci = cond_info[not_done_mask] if cond_info is not None else None + _, bck_cat, *_ = model(self.ctx.collate(torch_graphs).to(dev), ci) else: gbatch = self.ctx.collate(torch_graphs) 
action_types = self.ctx.bck_action_type_order diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 0d8948ac..bdd38aaa 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -140,7 +140,11 @@ def set_is_eval(self, is_eval: bool): self.is_eval = is_eval def create_training_data_from_own_samples( - self, model: TrajectoryBalanceModel, n: int, cond_info: Tensor, random_action_prob: float + self, + model: TrajectoryBalanceModel, + n: int, + cond_info: Optional[Tensor] = None, + random_action_prob: Optional[float] = 0.0, ): """Generate trajectories by sampling a model @@ -167,11 +171,12 @@ def create_training_data_from_own_samples( - is_valid: is the generated graph valid according to the env & ctx """ dev = get_worker_device() - cond_info = cond_info.to(dev) - data = self.graph_sampler.sample_from_model(model, n, cond_info, dev, random_action_prob) - logZ_pred = model.logZ(cond_info) - for i in range(n): - data[i]["logZ"] = logZ_pred[i].item() + cond_info = cond_info.to(dev) if cond_info is not None else None + data = self.graph_sampler.sample_from_model(model, n, cond_info, random_action_prob) + if cond_info is not None: + logZ_pred = model.logZ(cond_info) + for i in range(n): + data[i]["logZ"] = logZ_pred[i].item() return data def create_training_data_from_graphs( @@ -204,7 +209,7 @@ def create_training_data_from_graphs( dev = get_worker_device() cond_info = cond_info.to(dev) return self.graph_sampler.sample_backward_from_graphs( - graphs, model if self.cfg.do_parameterize_p_b else None, cond_info, dev, random_action_prob + graphs, model if self.cfg.do_parameterize_p_b else None, cond_info, random_action_prob ) trajs: List[Dict[str, Any]] = [{"traj": generate_forward_trajectory(i)} for i in graphs] for traj in trajs: @@ -333,6 +338,9 @@ def construct_batch(self, trajs, cond_info, log_rewards): batch.bck_ip_actions = torch.tensor(sum(bck_ipa, [])) batch.bck_ip_lens = 
torch.tensor([len(i) for i in bck_ipa]) + # compute_batch_losses expects these two optional values, if someone else doesn't fill them in, default to 0 + batch.num_offline = 0 + batch.num_online = 0 return batch def compute_batch_losses( @@ -358,7 +366,7 @@ def compute_batch_losses( clip_log_R = torch.maximum( log_rewards, torch.tensor(self.global_cfg.algo.illegal_action_logreward, device=dev) ).float() - cond_info = batch.cond_info + cond_info = getattr(batch, 'cond_info', None) invalid_mask = 1 - batch.is_valid # This index says which trajectory each graph belongs to, so @@ -367,16 +375,18 @@ def compute_batch_losses( batch_idx = torch.arange(num_trajs, device=dev).repeat_interleave(batch.traj_lens) # The position of the last graph of each trajectory final_graph_idx = torch.cumsum(batch.traj_lens, 0) - 1 + # The per-state cond_info + batched_cond_info = cond_info[batch_idx] if cond_info is not None else None # Forward pass of the model, returns a GraphActionCategorical representing the forward # policy P_F, optionally a backward policy P_B, and per-graph outputs (e.g. F(s) in SubTB). if self.cfg.do_parameterize_p_b: - fwd_cat, bck_cat, per_graph_out = model(batch, cond_info[batch_idx]) + fwd_cat, bck_cat, per_graph_out = model(batch, batched_cond_info) else: if self.model_is_autoregressive: fwd_cat, per_graph_out = model(batch, cond_info, batched=True) else: - fwd_cat, per_graph_out = model(batch, cond_info[batch_idx]) + fwd_cat, per_graph_out = model(batch, batched_cond_info) # Retreive the reward predictions for the full graphs, # i.e. 
the final graph of each trajectory log_reward_preds = per_graph_out[final_graph_idx, 0] diff --git a/src/gflownet/models/graph_transformer.py b/src/gflownet/models/graph_transformer.py index 8c3993f0..e3ed5ece 100644 --- a/src/gflownet/models/graph_transformer.py +++ b/src/gflownet/models/graph_transformer.py @@ -1,4 +1,5 @@ from itertools import chain +from typing import Optional import torch import torch.nn as nn @@ -59,7 +60,7 @@ def __init__(self, x_dim, e_dim, g_dim, num_emb=64, num_layers=3, num_heads=2, n self.x2h = mlp(x_dim + num_noise, num_emb, num_emb, 2) self.e2h = mlp(e_dim, num_emb, num_emb, 2) - self.c2h = mlp(g_dim, num_emb, num_emb, 2) + self.c2h = mlp(max(1, g_dim), num_emb, num_emb, 2) self.graph2emb = nn.ModuleList( sum( [ @@ -78,7 +79,7 @@ def __init__(self, x_dim, e_dim, g_dim, num_emb=64, num_layers=3, num_heads=2, n ) ) - def forward(self, g: gd.Batch, cond: torch.Tensor): + def forward(self, g: gd.Batch, cond: Optional[torch.Tensor]): """Forward pass Parameters @@ -100,7 +101,7 @@ def forward(self, g: gd.Batch, cond: torch.Tensor): x = g.x o = self.x2h(x) e = self.e2h(g.edge_attr) - c = self.c2h(cond) + c = self.c2h(cond if cond is not None else torch.ones((g.num_graphs, 1), device=g.x.device)) num_total_nodes = g.x.shape[0] # Augment the edges with a new edge to the conditioning # information node. 
This new node is connected to every node @@ -221,7 +222,12 @@ def __init__( self.emb2graph_out = mlp(num_glob_final, num_emb, num_graph_out, cfg.model.graph_transformer.num_mlp_layers) # TODO: flag for this - self.logZ = mlp(env_ctx.num_cond_dim, num_emb * 2, 1, 2) + self._logZ = mlp(max(1, env_ctx.num_cond_dim), num_emb * 2, 1, 2) + + def logZ(self, cond_info: Optional[torch.Tensor]): + if cond_info is None: + return self._logZ(torch.ones((1, 1), device=self._logZ[0].weight.device)) + return self._logZ(cond_info) def _action_type_to_mask(self, t, g): return getattr(g, t.mask_name) if hasattr(g, t.mask_name) else torch.ones((1, 1), device=g.x.device) @@ -244,7 +250,7 @@ def _make_cat(self, g, emb, action_types): types=action_types, ) - def forward(self, g: gd.Batch, cond: torch.Tensor): + def forward(self, g: gd.Batch, cond: Optional[torch.Tensor]): node_embeddings, graph_embeddings = self.transf(g, cond) # "Non-edges" are edges not currently in the graph that we could add if hasattr(g, "non_edge_index"): diff --git a/src/gflownet/models/mxmnet.py b/src/gflownet/models/mxmnet.py index e104afb5..facb6b41 100644 --- a/src/gflownet/models/mxmnet.py +++ b/src/gflownet/models/mxmnet.py @@ -249,7 +249,6 @@ def compute_idx(pos, edge_index): unique, counts = torch.unique(edge_index[0], sorted=True, return_counts=True) #Get central values full_index = torch.arange(0, edge_index[0].size()[0]).cuda().int() #init full index - import pdb; pdb.set_trace() #print('full_index', full_index) #Compute 1 diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index a5578dcf..ad854ebd 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -1,5 +1,5 @@ import socket -from typing import Callable, Dict, List, Tuple +from typing import Callable, Dict, List, Optional, Tuple import torch import torch.nn as nn @@ -32,13 +32,11 @@ class SEHTask(GFNTask): def __init__( self, - dataset: Dataset, cfg: Config, - wrap_model: Callable[[nn.Module], 
nn.Module] = None, + wrap_model: Optional[Callable[[nn.Module], nn.Module]] = None, ) -> None: - self._wrap_model = wrap_model + self._wrap_model = wrap_model if wrap_model is not None else (lambda x: x) self.models = self._load_task_models() - self.dataset = dataset self.temperature_conditional = TemperatureConditional(cfg) self.num_cond_dim = self.temperature_conditional.encoding_size() diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 88bc7420..7407e8f8 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -1,5 +1,5 @@ import pathlib -from typing import Any, Callable, Dict, List, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import numpy as np import torch @@ -8,7 +8,7 @@ from rdkit.Chem import QED, Descriptors from rdkit.Chem.rdchem import Mol as RDMol from torch import Tensor -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader from gflownet import LogScalar, ObjectProperties from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL @@ -63,16 +63,14 @@ class SEHMOOTask(SEHTask): def __init__( self, - dataset: Dataset, cfg: Config, - wrap_model: Callable[[nn.Module], nn.Module] = None, + wrap_model: Optional[Callable[[nn.Module], nn.Module]] = None, ): - super().__init__(dataset, cfg, wrap_model) + super().__init__(cfg, wrap_model) self.cfg = cfg mcfg = self.cfg.task.seh_moo self.objectives = cfg.task.seh_moo.objectives cfg.cond.moo.num_objectives = len(self.objectives) # This value is used by the focus_cond and pref_cond - self.dataset = dataset if self.cfg.cond.focus_region.focus_type is not None: self.focus_cond = FocusRegionConditional(self.cfg, mcfg.n_valid) else: