some renaming and None-proofing of cond_info
bengioe committed Mar 13, 2024
1 parent 82ee6c0 commit 71c0d0d
Showing 8 changed files with 83 additions and 35 deletions.
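The recurring change across the files below is to allow `cond_info` to be `None`, and to only index or move it when it is actually provided. A minimal illustrative sketch of that pattern (the helper name here is hypothetical, not part of the commit):

```python
from typing import Optional

from torch import Tensor


def maybe_slice_cond_info(cond_info: Optional[Tensor], mask: Tensor) -> Optional[Tensor]:
    # cond_info may now be None (e.g. for unconditional models);
    # only index it when it is present.
    return cond_info[mask] if cond_info is not None else None
```

This mirrors lines in the diff such as `ci = cond_info[not_done_mask] if cond_info is not None else None` in `graph_sampling.py`.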
34 changes: 34 additions & 0 deletions docs/contributing.md
@@ -0,0 +1,34 @@
# Contributing

Contributions to the repository are welcome, and we encourage you to open issues and pull requests. In general, it is recommended to fork this repository and open a pull request from your fork to the `trunk` branch. PRs are encouraged to be short and focused, and to include tests and documentation where appropriate.

## Installation

To install the developer dependencies, run:
```
pip install -e '.[dev]' --find-links https://data.pyg.org/whl/torch-2.1.2+cu121.html
```

## Dependencies

Dependencies are defined in `pyproject.toml`, and frozen versions that are known to work are provided in `requirements/`.

To regenerate the frozen versions, run `./generate_requirements.sh <ENV-NAME>`. See comments within.

## Linting and testing

We use `tox` to run tests and linting, and `pre-commit` to run checks before committing.
To make sure these checks pass locally, run `tox -e style` for the linters and `tox run` for the tests.

`tox` itself runs many linters, but the most important ones are `black`, `ruff`, `isort`, and `mypy`. The full list
of linting tools is found in `.pre-commit-config.yaml`, while `tox.ini` defines the environments under which these
linters (as well as tests) are run.

## GitHub Actions

We use GitHub Actions to run tests and linting on every push and pull request. The configuration for these workflows is found in `.github/workflows/`.

## Style Guide

On top of using `black` as a style guide, we generally adhere to the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html).
Our docstrings follow the [numpydoc](https://numpydoc.readthedocs.io/en/latest/format.html) format, and we use type hints throughout the codebase.
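
As an illustration (the function below is hypothetical, not part of the codebase), a docstring in this style looks like:

```python
def scale_rewards(rewards: list[float], beta: float = 1.0) -> list[float]:
    """Scale a list of rewards by an inverse temperature.

    Parameters
    ----------
    rewards : list[float]
        The raw reward values.
    beta : float, optional
        Inverse temperature multiplier, by default 1.0.

    Returns
    -------
    list[float]
        The scaled rewards.
    """
    return [beta * r for r in rewards]
```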
2 changes: 2 additions & 0 deletions docs/getting_started.md
@@ -0,0 +1,2 @@
# Getting Started

17 changes: 9 additions & 8 deletions src/gflownet/algo/graph_sampling.py
@@ -8,7 +8,7 @@

from gflownet.envs.graph_building_env import Graph, GraphAction, GraphActionCategorical, GraphActionType
from gflownet.models.graph_transformer import GraphTransformerGFN
from gflownet.utils.misc import get_worker_rng
from gflownet.utils.misc import get_worker_rng, get_worker_device


def relabel(g: Graph, ga: GraphAction):
@@ -63,7 +63,7 @@ def __init__(
self.pad_with_terminal_state = pad_with_terminal_state

def sample_from_model(
self, model: nn.Module, n: int, cond_info: Tensor, dev: torch.device, random_action_prob: float = 0.0
self, model: nn.Module, n: int, cond_info: Optional[Tensor], random_action_prob: float = 0.0
):
"""Samples a model in a minibatch
@@ -75,8 +75,6 @@ def sample_from_model(
Number of graphs to sample
cond_info: Tensor
Conditional information of each trajectory, shape (n, n_info)
dev: torch.device
Device on which data is manipulated
Returns
-------
@@ -87,6 +85,7 @@
- bck_logprob: sum logprobs P_B
- is_valid: is the generated graph valid according to the env & ctx
"""
dev = get_worker_device()
# This will be returned
data = [{"traj": [], "reward_pred": None, "is_valid": True, "is_sink": []} for i in range(n)]
# Let's also keep track of trajectory statistics according to the model
@@ -114,7 +113,8 @@ def not_done(lst):
# Forward pass to get GraphActionCategorical
# Note about `*_`, the model may be outputting its own bck_cat, but we ignore it if it does.
# TODO: compute bck_cat.log_prob(bck_a) when relevant
fwd_cat, *_, log_reward_preds = model(self.ctx.collate(torch_graphs).to(dev), cond_info[not_done_mask])
ci = cond_info[not_done_mask] if cond_info is not None else None
fwd_cat, *_, log_reward_preds = model(self.ctx.collate(torch_graphs).to(dev), ci)
if random_action_prob > 0:
masks = [1] * len(fwd_cat.logits) if fwd_cat.masks is None else fwd_cat.masks
# Decide which graphs in the minibatch will get their action randomized
@@ -208,8 +208,7 @@ def sample_backward_from_graphs(
self,
graphs: List[Graph],
model: Optional[nn.Module],
cond_info: Tensor,
dev: torch.device,
cond_info: Optional[Tensor],
random_action_prob: float = 0.0,
):
"""Sample a model's P_B starting from a list of graphs, or if the model is None, use a uniform distribution
@@ -229,6 +228,7 @@
Probability of taking a random action (only used if model parameterizes P_B)
"""
dev = get_worker_device()
n = len(graphs)
done = [False] * n
data = [
@@ -254,7 +254,8 @@ def not_done(lst):
torch_graphs = [self.ctx.graph_to_Data(graphs[i]) for i in not_done(range(n))]
not_done_mask = torch.tensor(done, device=dev).logical_not()
if model is not None:
_, bck_cat, *_ = model(self.ctx.collate(torch_graphs).to(dev), cond_info[not_done_mask])
ci = cond_info[not_done_mask] if cond_info is not None else None
_, bck_cat, *_ = model(self.ctx.collate(torch_graphs).to(dev), ci)
else:
gbatch = self.ctx.collate(torch_graphs)
action_types = self.ctx.bck_action_type_order
30 changes: 20 additions & 10 deletions src/gflownet/algo/trajectory_balance.py
@@ -140,7 +140,11 @@ def set_is_eval(self, is_eval: bool):
self.is_eval = is_eval

def create_training_data_from_own_samples(
self, model: TrajectoryBalanceModel, n: int, cond_info: Tensor, random_action_prob: float
self,
model: TrajectoryBalanceModel,
n: int,
cond_info: Optional[Tensor] = None,
random_action_prob: Optional[float] = 0.0,
):
"""Generate trajectories by sampling a model
@@ -167,11 +171,12 @@
- is_valid: is the generated graph valid according to the env & ctx
"""
dev = get_worker_device()
cond_info = cond_info.to(dev)
data = self.graph_sampler.sample_from_model(model, n, cond_info, dev, random_action_prob)
logZ_pred = model.logZ(cond_info)
for i in range(n):
data[i]["logZ"] = logZ_pred[i].item()
cond_info = cond_info.to(dev) if cond_info is not None else None
data = self.graph_sampler.sample_from_model(model, n, cond_info, random_action_prob)
if cond_info is not None:
logZ_pred = model.logZ(cond_info)
for i in range(n):
data[i]["logZ"] = logZ_pred[i].item()
return data

def create_training_data_from_graphs(
@@ -204,7 +209,7 @@ def create_training_data_from_graphs(
dev = get_worker_device()
cond_info = cond_info.to(dev)
return self.graph_sampler.sample_backward_from_graphs(
graphs, model if self.cfg.do_parameterize_p_b else None, cond_info, dev, random_action_prob
graphs, model if self.cfg.do_parameterize_p_b else None, cond_info, random_action_prob
)
trajs: List[Dict[str, Any]] = [{"traj": generate_forward_trajectory(i)} for i in graphs]
for traj in trajs:
@@ -333,6 +338,9 @@ def construct_batch(self, trajs, cond_info, log_rewards):
batch.bck_ip_actions = torch.tensor(sum(bck_ipa, []))
batch.bck_ip_lens = torch.tensor([len(i) for i in bck_ipa])

# compute_batch_losses expects these two optional values, if someone else doesn't fill them in, default to 0
batch.num_offline = 0
batch.num_online = 0
return batch

def compute_batch_losses(
@@ -358,7 +366,7 @@
clip_log_R = torch.maximum(
log_rewards, torch.tensor(self.global_cfg.algo.illegal_action_logreward, device=dev)
).float()
cond_info = batch.cond_info
cond_info = getattr(batch, 'cond_info', None)
invalid_mask = 1 - batch.is_valid

# This index says which trajectory each graph belongs to, so
@@ -367,16 +375,18 @@
batch_idx = torch.arange(num_trajs, device=dev).repeat_interleave(batch.traj_lens)
# The position of the last graph of each trajectory
final_graph_idx = torch.cumsum(batch.traj_lens, 0) - 1
# The per-state cond_info
batched_cond_info = cond_info[batch_idx] if cond_info is not None else None

# Forward pass of the model, returns a GraphActionCategorical representing the forward
# policy P_F, optionally a backward policy P_B, and per-graph outputs (e.g. F(s) in SubTB).
if self.cfg.do_parameterize_p_b:
fwd_cat, bck_cat, per_graph_out = model(batch, cond_info[batch_idx])
fwd_cat, bck_cat, per_graph_out = model(batch, batched_cond_info)
else:
if self.model_is_autoregressive:
fwd_cat, per_graph_out = model(batch, cond_info, batched=True)
else:
fwd_cat, per_graph_out = model(batch, cond_info[batch_idx])
fwd_cat, per_graph_out = model(batch, batched_cond_info)
# Retrieve the reward predictions for the full graphs,
# i.e. the final graph of each trajectory
log_reward_preds = per_graph_out[final_graph_idx, 0]
16 changes: 11 additions & 5 deletions src/gflownet/models/graph_transformer.py
@@ -1,4 +1,5 @@
from itertools import chain
from typing import Optional

import torch
import torch.nn as nn
@@ -59,7 +60,7 @@ def __init__(self, x_dim, e_dim, g_dim, num_emb=64, num_layers=3, num_heads=2, n

self.x2h = mlp(x_dim + num_noise, num_emb, num_emb, 2)
self.e2h = mlp(e_dim, num_emb, num_emb, 2)
self.c2h = mlp(g_dim, num_emb, num_emb, 2)
self.c2h = mlp(max(1, g_dim), num_emb, num_emb, 2)
self.graph2emb = nn.ModuleList(
sum(
[
@@ -78,7 +79,7 @@
)
)

def forward(self, g: gd.Batch, cond: torch.Tensor):
def forward(self, g: gd.Batch, cond: Optional[torch.Tensor]):
"""Forward pass
Parameters
@@ -100,7 +101,7 @@ def forward(self, g: gd.Batch, cond: torch.Tensor):
x = g.x
o = self.x2h(x)
e = self.e2h(g.edge_attr)
c = self.c2h(cond)
c = self.c2h(cond if cond is not None else torch.ones((g.num_graphs, 1), device=g.x.device))
num_total_nodes = g.x.shape[0]
# Augment the edges with a new edge to the conditioning
# information node. This new node is connected to every node
@@ -221,7 +222,12 @@ def __init__(

self.emb2graph_out = mlp(num_glob_final, num_emb, num_graph_out, cfg.model.graph_transformer.num_mlp_layers)
# TODO: flag for this
self.logZ = mlp(env_ctx.num_cond_dim, num_emb * 2, 1, 2)
self._logZ = mlp(max(1, env_ctx.num_cond_dim), num_emb * 2, 1, 2)

def logZ(self, cond_info: Optional[torch.Tensor]):
if cond_info is None:
return self._logZ(torch.ones((1, 1), device=self._logZ[0].weight.device))
return self._logZ(cond_info)

def _action_type_to_mask(self, t, g):
return getattr(g, t.mask_name) if hasattr(g, t.mask_name) else torch.ones((1, 1), device=g.x.device)
@@ -244,7 +250,7 @@ def _make_cat(self, g, emb, action_types):
types=action_types,
)

def forward(self, g: gd.Batch, cond: torch.Tensor):
def forward(self, g: gd.Batch, cond: Optional[torch.Tensor]):
node_embeddings, graph_embeddings = self.transf(g, cond)
# "Non-edges" are edges not currently in the graph that we could add
if hasattr(g, "non_edge_index"):
1 change: 0 additions & 1 deletion src/gflownet/models/mxmnet.py
@@ -249,7 +249,6 @@ def compute_idx(pos, edge_index):

unique, counts = torch.unique(edge_index[0], sorted=True, return_counts=True) #Get central values
full_index = torch.arange(0, edge_index[0].size()[0]).cuda().int() #init full index
import pdb; pdb.set_trace()
#print('full_index', full_index)

#Compute 1
8 changes: 3 additions & 5 deletions src/gflownet/tasks/seh_frag.py
@@ -1,5 +1,5 @@
import socket
from typing import Callable, Dict, List, Tuple
from typing import Callable, Dict, List, Optional, Tuple

import torch
import torch.nn as nn
@@ -32,13 +32,11 @@ class SEHTask(GFNTask):

def __init__(
self,
dataset: Dataset,
cfg: Config,
wrap_model: Callable[[nn.Module], nn.Module] = None,
wrap_model: Optional[Callable[[nn.Module], nn.Module]] = None,
) -> None:
self._wrap_model = wrap_model
self._wrap_model = wrap_model if wrap_model is not None else (lambda x: x)
self.models = self._load_task_models()
self.dataset = dataset
self.temperature_conditional = TemperatureConditional(cfg)
self.num_cond_dim = self.temperature_conditional.encoding_size()

10 changes: 4 additions & 6 deletions src/gflownet/tasks/seh_frag_moo.py
@@ -1,5 +1,5 @@
import pathlib
from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
@@ -8,7 +8,7 @@
from rdkit.Chem import QED, Descriptors
from rdkit.Chem.rdchem import Mol as RDMol
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import DataLoader

from gflownet import LogScalar, ObjectProperties
from gflownet.algo.envelope_q_learning import EnvelopeQLearning, GraphTransformerFragEnvelopeQL
@@ -63,16 +63,14 @@ class SEHMOOTask(SEHTask):

def __init__(
self,
dataset: Dataset,
cfg: Config,
wrap_model: Callable[[nn.Module], nn.Module] = None,
wrap_model: Optional[Callable[[nn.Module], nn.Module]] = None,
):
super().__init__(dataset, cfg, wrap_model)
super().__init__(cfg, wrap_model)
self.cfg = cfg
mcfg = self.cfg.task.seh_moo
self.objectives = cfg.task.seh_moo.objectives
cfg.cond.moo.num_objectives = len(self.objectives) # This value is used by the focus_cond and pref_cond
self.dataset = dataset
if self.cfg.cond.focus_region.focus_type is not None:
self.focus_cond = FocusRegionConditional(self.cfg, mcfg.n_valid)
else:
