From cc12abcb16e71d0ed673c3980b4e219c43b1add2 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Fri, 4 Aug 2023 16:55:23 -0400 Subject: [PATCH 1/7] merge & squash to refresh branch --- docs/implementation_notes.md | 2 +- src/gflownet/algo/graph_sampling.py | 4 +- src/gflownet/algo/trajectory_balance.py | 15 +- src/gflownet/config.py | 1 + src/gflownet/data/sampling_iterator.py | 25 +- src/gflownet/envs/basic_graph_ctx.py | 198 +++++ src/gflownet/envs/graph_building_env.py | 3 +- src/gflownet/tasks/basic_graph_task.py | 917 ++++++++++++++++++++++++ src/gflownet/tasks/config.py | 20 + src/gflownet/tasks/seh_frag.py | 2 +- src/gflownet/tasks/seh_frag_moo.py | 2 +- src/gflownet/trainer.py | 9 +- 12 files changed, 1175 insertions(+), 23 deletions(-) create mode 100644 src/gflownet/envs/basic_graph_ctx.py create mode 100644 src/gflownet/tasks/basic_graph_task.py diff --git a/docs/implementation_notes.md b/docs/implementation_notes.md index 6930728d..55698169 100644 --- a/docs/implementation_notes.md +++ b/docs/implementation_notes.md @@ -33,4 +33,4 @@ The code contains a specific categorical distribution type for graph actions, `G Consider for example the `AddNode` and `SetEdgeAttr` actions, one applies to nodes and one to edges. An efficient way to produce logits for these actions would be to take the node/edge embeddings and project them (e.g. via an MLP) to a `(n_nodes, n_node_actions)` and `(n_edges, n_edge_actions)` tensor respectively. We thus obtain a list of tensors representing the logits of different actions, but logits are mixed between graphs in the minibatch, so one cannot simply apply a `softmax` operator on the tensor. -The `GraphActionCategorical` class handles this and can be used to compute various other things, such as entropy, log probabilities, and so on; it can also be used to sample from the distribution. 
\ No newline at end of file +The `GraphActionCategorical` class handles this and can be used to compute various other things, such as entropy, log probabilities, and so on; it can also be used to sample from the distribution. diff --git a/src/gflownet/algo/graph_sampling.py b/src/gflownet/algo/graph_sampling.py index 039fc158..6021876c 100644 --- a/src/gflownet/algo/graph_sampling.py +++ b/src/gflownet/algo/graph_sampling.py @@ -113,7 +113,9 @@ def not_done(lst): ] if self.sample_temp != 1: sample_cat = copy.copy(fwd_cat) - sample_cat.logits = [i / self.sample_temp for i in fwd_cat.logits] + sample_cat.logits = [ + i * m / self.sample_temp - 1000 * (1 - m) for i, m in zip(fwd_cat.logits, fwd_cat.masks) + ] actions = sample_cat.sample() else: actions = fwd_cat.sample() diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 22fe655e..21122d2a 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -6,7 +6,7 @@ import torch.nn as nn import torch_geometric.data as gd from torch import Tensor -from torch_scatter import scatter, scatter_sum +from torch_scatter import scatter, scatter_sum, scatter_logsumexp from gflownet.algo.graph_sampling import GraphSampler from gflownet.config import Config @@ -309,11 +309,7 @@ def compute_batch_losses( # Indicate that the `batch` corresponding to each action is the above ip_log_prob = fwd_cat.log_prob(batch.ip_actions, batch=ip_batch_idces) # take the logsumexp (because we want to sum probabilities, not log probabilities) - # TODO: numerically stable version: - p = scatter(ip_log_prob.exp(), ip_batch_idces, dim=0, dim_size=batch_idx.shape[0], reduce="sum") - # As a (reasonable) band-aid, ignore p < 1e-30, this will prevent underflows due to - # scatter(small number) = 0 on CUDA - log_p_F = p.clamp(1e-30).log() + log_p_F = scatter_logsumexp(ip_log_prob, ip_batch_idces, dim=0, dim_size=batch_idx.shape[0]) if self.cfg.do_parameterize_p_b: # Now 
we repeat this but for the backward policy @@ -321,10 +317,7 @@ def compute_batch_losses( batch.bck_ip_lens ) bck_ip_log_prob = bck_cat.log_prob(batch.bck_ip_actions, batch=bck_ip_batch_idces) - bck_p = scatter( - bck_ip_log_prob.exp(), bck_ip_batch_idces, dim=0, dim_size=batch_idx.shape[0], reduce="sum" - ) - log_p_B = bck_p.clamp(1e-30).log() + log_p_B = scatter_logsumexp(bck_ip_log_prob, bck_ip_batch_idces, dim=0, dim_size=batch_idx.shape[0]) else: # Else just naively take the logprob of the actions we took log_p_F = fwd_cat.log_prob(batch.actions) @@ -496,7 +489,7 @@ def subtb_loss_fast(self, P_F, P_B, F, R, traj_lengths): cumul_lens = torch.cumsum(torch.cat([torch.zeros(1, device=dev), traj_lengths]), 0).long() total_loss = torch.zeros(num_trajs, device=dev) ar = torch.arange(max_len, device=dev) - car = torch.cumsum(ar, 0) + car = torch.cumsum(ar, 0) if self.length_normalize_losses else torch.ones_like(ar) F_and_R = torch.cat([F, R]) R_start = F.shape[0] for ep in range(traj_lengths.shape[0]): diff --git a/src/gflownet/config.py b/src/gflownet/config.py index be4fa879..e1be215e 100644 --- a/src/gflownet/config.py +++ b/src/gflownet/config.py @@ -81,6 +81,7 @@ class Config: """ log_dir: str = MISSING + log_sampled_data: bool = True device: str = "cuda" seed: int = 0 validate_every: int = 1000 diff --git a/src/gflownet/data/sampling_iterator.py b/src/gflownet/data/sampling_iterator.py index 90b8b4db..3163c8fd 100644 --- a/src/gflownet/data/sampling_iterator.py +++ b/src/gflownet/data/sampling_iterator.py @@ -1,5 +1,6 @@ import os import sqlite3 +import traceback from collections.abc import Iterable from copy import deepcopy from typing import Callable, List @@ -11,6 +12,7 @@ from rdkit import Chem, RDLogger from torch.utils.data import Dataset, IterableDataset +from gflownet.config import Config from gflownet.data.replay_buffer import ReplayBuffer from gflownet.envs.graph_building_env import GraphActionCategorical @@ -112,9 +114,14 @@ def __init__( # don't want 
to initialize per-worker things just yet, such as where the log the worker writes # to. This must be done in __iter__, which is called by the DataLoader once this instance # has been copied into a new python process. - self.log_dir = log_dir + import warnings + + warnings.warn("Fix dependency on cfg.log_sampled_data") + self.log_dir = log_dir # if cfg.log_sampled_data else None self.log = SQLiteLog() self.log_hooks: List[Callable] = [] + # TODO: make this a proper flag / make a separate class for logging sampled molecules to a SQLite db + self.log_molecule_smis = not hasattr(self.ctx, "not_a_molecule_env") and self.log_dir is not None def add_log_hook(self, hook: Callable): self.log_hooks.append(hook) @@ -158,6 +165,14 @@ def __len__(self): return len(self.data) def __iter__(self): + try: + for x in self.iterator(): + yield x + except Exception as e: + traceback.print_exc() + raise e + + def iterator(self): worker_info = torch.utils.data.get_worker_info() self._wid = worker_info.id if worker_info is not None else 0 # Now that we know we are in a worker instance, we can initialize per-worker things @@ -189,9 +204,7 @@ def __iter__(self): else: # If we're not sampling the conditionals, then the idcs refer to listed preferences num_online = num_offline num_offline = 0 - cond_info = self.task.encode_conditional_information( - steer_info=torch.stack([self.data[i] for i in idcs]) - ) + cond_info = self.task.encode_conditional_information(torch.stack([self.data[i] for i in idcs])) trajs, flat_rewards = [], [] # Sample some on-policy data @@ -250,7 +263,7 @@ def __iter__(self): # note: we convert back into natural rewards for logging purposes # (allows to take averages and plot in objective space) # TODO: implement that per-task (in case they don't apply the same beta and log transformations) - rewards = torch.exp(log_rewards / cond_info["beta"]) + rewards = torch.exp(log_rewards / (cond_info["beta"] if "beta" in cond_info else 1.0)) if num_online > 0 and self.log_dir is 
not None: self.log_generated( deepcopy(trajs[num_offline:]), @@ -258,6 +271,8 @@ def __iter__(self): deepcopy(flat_rewards[num_offline:]), {k: v[num_offline:] for k, v in deepcopy(cond_info).items()}, ) + + extra_info = {} if num_online > 0: for hook in self.log_hooks: extra_info.update( diff --git a/src/gflownet/envs/basic_graph_ctx.py b/src/gflownet/envs/basic_graph_ctx.py new file mode 100644 index 00000000..fd04e5e0 --- /dev/null +++ b/src/gflownet/envs/basic_graph_ctx.py @@ -0,0 +1,198 @@ +from typing import Dict, List, Tuple + +import networkx as nx +import torch +import torch_geometric.data as gd +from networkx.algorithms.isomorphism import is_isomorphic as nx_is_isomorphic + +from gflownet.envs.graph_building_env import ( + Graph, + GraphAction, + GraphActionType, + GraphBuildingEnvContext, + graph_without_edge, +) +from gflownet.utils.graphs import random_walk_probs + + +def hashg(g): + return nx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash(g, node_attr="v") + + +def is_isomorphic(u, v): + return nx_is_isomorphic(u, v, lambda a, b: a == b, lambda a, b: a == b) + + +class BasicGraphContext(GraphBuildingEnvContext): + """ + A basic graph generation context. + + This simple environment context is designed to be used to test implementations. It only allows for AddNode and + AddEdge actions, and is meant to be used within the BasicGraphTask to generate graphs of up to 7 nodes with + only two possible node attributes, making the state space a total of ~70k states (which is nicely enumerable + and allows us to compute p_theta(x) exactly for all x in the state space). 
+ """ + + def __init__(self, max_nodes=7, num_cond_dim=0, graph_data=None, output_gid=False): + self.max_nodes = max_nodes + self.output_gid = output_gid + + self.node_attr_values = { + "v": [0, 1], # Imagine this is as colors + } + self._num_rw_feat = 8 + + self.num_new_node_values = len(self.node_attr_values["v"]) + self.num_node_attr_logits = None + self.num_node_dim = self.num_new_node_values + 1 + self._num_rw_feat + self.num_node_attrs = 1 + self.num_edge_attr_logits = None + self.num_edge_attrs = 0 + self.num_cond_dim = num_cond_dim + self.num_edge_dim = 1 + self.edges_are_duplicated = True + self.edges_are_unordered = True + + # Order in which models have to output logits + self.action_type_order = [ + GraphActionType.Stop, + GraphActionType.AddNode, + GraphActionType.AddEdge, + ] + self.bck_action_type_order = [ + GraphActionType.RemoveNode, + GraphActionType.RemoveEdge, + ] + self.device = torch.device("cpu") + self.graph_data = graph_data + self.hash_to_graphs: Dict[str, int] = {} + if graph_data is not None: + states_hash = [hashg(i) for i in graph_data] + for i, h, g in zip(range(len(graph_data)), states_hash, graph_data): + self.hash_to_graphs[h] = self.hash_to_graphs.get(h, list()) + [(g, i)] + + def get_graph_idx(self, g, default=None): + h = hashg(g) + if h not in self.hash_to_graphs and default is not None: + return default + bucket = self.hash_to_graphs[h] + if len(bucket) == 1: + return bucket[0][1] + for i in bucket: + if is_isomorphic(i[0], g): + return i[1] + if default is not None: + return default + raise ValueError(g) + + def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: bool = True): + """Translate an action index (e.g. 
from a GraphActionCategorical) to a GraphAction""" + act_type, act_row, act_col = [int(i) for i in action_idx] + if fwd: + t = self.action_type_order[act_type] + else: + t = self.bck_action_type_order[act_type] + + if t is GraphActionType.Stop: + return GraphAction(t) + elif t is GraphActionType.AddNode: + return GraphAction(t, source=act_row, value=self.node_attr_values["v"][act_col]) + elif t is GraphActionType.AddEdge: + a, b = g.non_edge_index[:, act_row] + return GraphAction(t, source=a.item(), target=b.item()) + elif t is GraphActionType.RemoveNode: + return GraphAction(t, source=act_row) + elif t is GraphActionType.RemoveEdge: + a, b = g.edge_index[:, act_row * 2] + return GraphAction(t, source=a.item(), target=b.item()) + + def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int, int]: + """Translate a GraphAction to an index tuple""" + if action.action is GraphActionType.Stop: + row = col = 0 + type_idx = self.action_type_order.index(action.action) + elif action.action is GraphActionType.AddNode: + row = action.source + col = self.node_attr_values["v"].index(action.value) + type_idx = self.action_type_order.index(action.action) + elif action.action is GraphActionType.AddEdge: + # Here we have to retrieve the index in non_edge_index of an edge (s,t) + # that's also possibly in the reverse order (t,s). + # That's definitely not too efficient, can we do better? 
+ row = ( + (g.non_edge_index.T == torch.tensor([(action.source, action.target)])).prod(1) + + (g.non_edge_index.T == torch.tensor([(action.target, action.source)])).prod(1) + ).argmax() + col = 0 + type_idx = self.action_type_order.index(action.action) + elif action.action is GraphActionType.RemoveNode: + row = action.source + col = 0 + type_idx = self.bck_action_type_order.index(action.action) + elif action.action is GraphActionType.RemoveEdge: + row = ((g.edge_index.T == torch.tensor([(action.source, action.target)])).prod(1)).argmax() + row = int(row) // 2 # edges are duplicated, but edge logits are not + col = 0 + type_idx = self.bck_action_type_order.index(action.action) + return (type_idx, int(row), int(col)) + + def graph_to_Data(self, g: Graph) -> gd.Data: + """Convert a networkx Graph to a torch geometric Data instance""" + x = torch.zeros((max(1, len(g.nodes)), self.num_node_dim - self._num_rw_feat)) + x[0, -1] = len(g.nodes) == 0 + remove_node_mask = torch.zeros((x.shape[0], 1)) + (1 if len(g) == 0 else 0) + for i, n in enumerate(g.nodes): + ad = g.nodes[n] + x[i, self.node_attr_values["v"].index(ad["v"])] = 1 + if g.degree(n) <= 1: + remove_node_mask[i] = 1 + + remove_edge_mask = torch.zeros((len(g.edges), 1)) + for i, (u, v) in enumerate(g.edges): + if g.degree(u) > 1 and g.degree(v) > 1: + if nx.algorithms.is_connected(graph_without_edge(g, (u, v))): + remove_edge_mask[i] = 1 + edge_attr = torch.zeros((len(g.edges) * 2, self.num_edge_dim)) + edge_index = ( + torch.tensor([e for i, j in g.edges for e in [(i, j), (j, i)]], dtype=torch.long).reshape((-1, 2)).T + ) + gc = nx.complement(g) + non_edge_index = torch.tensor([i for i in gc.edges], dtype=torch.long).reshape((-1, 2)).T + gid = self.get_graph_idx(g) if self.output_gid else 0 + + return self._preprocess( + gd.Data( + x, + edge_index, + edge_attr, + non_edge_index=non_edge_index, + stop_mask=torch.ones((1, 1)), + add_node_mask=torch.ones((x.shape[0], self.num_new_node_values)) * (len(g) < 
self.max_nodes), + add_edge_mask=torch.ones((non_edge_index.shape[1], 1)), + remove_node_mask=remove_node_mask, + remove_edge_mask=remove_edge_mask, + gid=gid, + ) + ) + + def _preprocess(self, g: gd.Data) -> gd.Data: + if self._num_rw_feat > 0: + g.x = torch.cat([g.x, random_walk_probs(g, self._num_rw_feat, skip_odd=True)], 1) + return g + + def collate(self, graphs: List[gd.Data]): + """Batch Data instances""" + return gd.Batch.from_data_list(graphs, follow_batch=["edge_index", "non_edge_index"]) + + def mol_to_graph(self, obj: Graph) -> Graph: + return obj # This is already a graph + + def graph_to_mol(self, g: Graph) -> Graph: + # idem + return g + + def is_sane(self, g: Graph) -> bool: + return True + + def get_object_description(self, g: Graph, is_valid: bool) -> str: + return str(self.get_graph_idx(g, -1)) diff --git a/src/gflownet/envs/graph_building_env.py b/src/gflownet/envs/graph_building_env.py index fa7b284b..bd3c299d 100644 --- a/src/gflownet/envs/graph_building_env.py +++ b/src/gflownet/envs/graph_building_env.py @@ -331,7 +331,7 @@ def generate_forward_trajectory(g: Graph, max_nodes: int = None) -> List[Tuple[G # TODO: should this be a method of GraphBuildingEnv? handle set_node_attr flags and so on? gn = Graph() # Choose an arbitrary starting point, add to the stack - stack: List[Tuple[int, ...]] = [(np.random.randint(0, len(g.nodes)),)] + stack: List[Tuple[int, ...]] = [(np.random.randint(0, len(g.nodes)),)] if len(g.nodes) > 0 else [] traj = [] # This map keeps track of node labels in gn, since we have to start from 0 relabeling_map: Dict[int, int] = {} @@ -777,6 +777,7 @@ class GraphBuildingEnvContext: """A context class defines what the graphs are, how they map to and from data""" device: torch.device + num_cond_dim: int = 0 def aidx_to_GraphAction(self, g: gd.Data, action_idx: Tuple[int, int, int], fwd: bool = True) -> GraphAction: """Translate an action index (e.g. 
from a GraphActionCategorical) to a GraphAction diff --git a/src/gflownet/tasks/basic_graph_task.py b/src/gflownet/tasks/basic_graph_task.py new file mode 100644 index 00000000..a79b43e1 --- /dev/null +++ b/src/gflownet/tasks/basic_graph_task.py @@ -0,0 +1,917 @@ +import os +import bz2 +import pickle +from typing import Dict, List, Tuple + +import networkx as nx +import numpy as np +import torch +import torch.nn as nn +import torch_geometric.data as gd +from networkx.algorithms.isomorphism import is_isomorphic +from torch import Tensor +from torch.utils.data import DataLoader, Dataset +from torch_scatter import scatter_logsumexp +from tqdm import tqdm + +from gflownet.algo.flow_matching import FlowMatching +from gflownet.algo.trajectory_balance import TrajectoryBalance +from gflownet.config import Config +from gflownet.envs.basic_graph_ctx import BasicGraphContext +from gflownet.envs.graph_building_env import ( + Graph, + GraphAction, + GraphActionType, + GraphBuildingEnv, + GraphActionCategorical, +) +from gflownet.models.graph_transformer import GraphTransformer, GraphTransformerGFN +from gflownet.trainer import FlatRewards, GFNAlgorithm, GFNTask, GFNTrainer, RewardScalar + + +def n_clique_reward(g, n=4): + cliques = list(nx.algorithms.clique.find_cliques(g)) + # The number of cliques each node belongs to + num_cliques = np.bincount(sum(cliques, [])) + cliques_match = [len(i) == n for i in cliques] + return np.mean(cliques_match) - np.mean(num_cliques) + + +def colored_n_clique_reward(g, n=4): + cliques = list(nx.algorithms.clique.find_cliques(g)) + # The number of cliques each node belongs to + num_cliques = np.bincount(sum(cliques, [])) + colors = {i: g.nodes[i]["v"] for i in g.nodes} + + def color_match(c): + return np.bincount([colors[i] for i in c]).max() >= n - 1 + + cliques_match = [float(len(i) == n) * (1 if color_match(i) else 0.5) for i in cliques] + return np.maximum(np.sum(cliques_match) - np.sum(num_cliques) + len(g) - 1, -10) + + +def 
even_neighbors_reward(g): + total_correct = 0 + for n in g: + num_diff_colr = 0 + c = g.nodes[n]["v"] + for i in g.neighbors(n): + num_diff_colr += int(g.nodes[i]["v"] != c) + total_correct += int(num_diff_colr % 2 == 0) - (1 if num_diff_colr == 0 else 0) + return np.float32((total_correct - len(g.nodes) if len(g.nodes) > 3 else -5) * 10 / 7) + + +def count_reward(g): + ncols = np.bincount([g.nodes[i]["v"] for i in g], minlength=2) + return np.float32(-abs(ncols[0] + ncols[1] / 2 - 3) / 4 * 10) + + +def generate_two_col_data(data_root, max_nodes=7): + atl = nx.generators.atlas.graph_atlas_g() + # Filter out disconnected graphs + conn = [i for i in atl if 1 <= len(i.nodes) <= max_nodes and nx.is_connected(i)] + # Create all possible two-colored graphs + two_col_graphs = [nx.Graph()] + print(len(conn)) + pb = tqdm(range(117142), disable=None) + hashes = {} + rejected = 0 + + def node_eq(a, b): + return a == b + + for g in conn: + for i in range(2 ** len(g.nodes)): + g = g.copy() + for j in range(len(g.nodes)): + bit = i % 2 + i //= 2 + g.nodes[j]["v"] = bit + h = nx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash(g, node_attr="v") + if h not in hashes: + hashes[h] = [g] + two_col_graphs.append(g) + else: + if not any(nx.algorithms.isomorphism.is_isomorphic(g, gp, node_eq) for gp in hashes[h]): + hashes[h].append(g) + two_col_graphs.append(g) + else: + pb.set_description(f"{rejected}", refresh=False) + rejected += 1 + pb.update(1) + with bz2.open(data_root + f"/two_col_{max_nodes}_graphs.pkl.bz", "wb") as f: + pickle.dump(two_col_graphs, f) + return two_col_graphs + + +def load_two_col_data(data_root, max_nodes=7, generate_if_missing=True): + p = data_root + f"/two_col_{max_nodes}_graphs.pkl.bz" + print("Loading", p) + if not os.path.exists(p) and generate_if_missing: + return generate_two_col_data(data_root, max_nodes=max_nodes) + with bz2.open(p, "rb") as f: + data = pickle.load(f) + return data + + +class GraphTransformerRegressor(GraphTransformer): + def 
__init__(self, *args, **kw): + super().__init__(*args, **kw) + self.g2o = torch.nn.Linear(kw["num_emb"] * 2, 1) + + def forward(self, g: gd.Batch, cond: torch.Tensor): + per_node_pred, per_graph_pred = super().forward(g, cond) + return self.g2o(per_graph_pred)[:, 0] + + +class TwoColorGraphDataset(Dataset): + def __init__( + self, + data, + ctx, + train=True, + output_graphs=False, + split_seed=142857, + ratio=0.9, + max_nodes=7, + reward_func="cliques", + ): + self.data = data + self.ctx = ctx + self.output_graphs = output_graphs + self.reward_func = reward_func + self.idcs = [0] + self.max_nodes = max_nodes + if data is None: + return + + idcs = np.arange(len(data)) + rng = np.random.default_rng(split_seed) + rng.shuffle(idcs) + if train: + self.idcs = idcs[: int(np.floor(ratio * len(data)))] + else: + self.idcs = idcs[int(np.floor(ratio * len(data))) :] + + print(train, self.idcs.shape) + self._gc = nx.complete_graph(7) + self._enum_edges = list(self._gc.edges) + self.compute_Fsa = False + self.compute_normalized_Fsa = False + self.regress_to_F = False + + def __len__(self): + return len(self.idcs) + + def reward(self, g): + if len(g.nodes) > self.max_nodes: + return -100 + if self.reward_func == "cliques": + return colored_n_clique_reward(g) + elif self.reward_func == "even_neighbors": + return even_neighbors_reward(g) + elif self.reward_func == "count": + return count_reward(g) + elif self.reward_func == "const": + return np.float32(0) + + def collate_fn(self, batch): + graphs, rewards, idcs = zip(*batch) + batch = self.ctx.collate(graphs) + if self.regress_to_F: + batch.y = torch.as_tensor([self.epc.mdp_graph.nodes[i]["F"] for i in idcs]) + else: + batch.y = torch.as_tensor(rewards) + if self.compute_Fsa: + all_targets = [] + for data_idx in idcs: + targets = [ + torch.zeros_like(getattr(self.epc._Data[data_idx], i.mask_name)) - 100 + for i in self.ctx.action_type_order + ] + for neighbor in list(self.epc.mdp_graph.neighbors(data_idx)): + for _, edge in 
self.epc.mdp_graph.get_edge_data(data_idx, neighbor).items(): + a, F = edge["a"], edge["F"] + targets[a[0]][a[1], a[2]] = F + if self.compute_normalized_Fsa: + logZ = torch.log(sum([i.exp().sum() for i in targets])) + targets = [i - logZ for i in targets] + all_targets.append(targets) + batch.y = torch.cat([torch.cat(i).flatten() for i in zip(*all_targets)]) + return batch + + def __getitem__(self, idx): + idx = self.idcs[idx] + g = self.data[idx] + r = torch.tensor(self.reward(g).reshape((1,))) + if self.output_graphs: + return self.ctx.graph_to_Data(g), r, idx + else: + return g, r + + +class BasicGraphTask(GFNTask): + def __init__( + self, + cfg: Config, + dataset: TwoColorGraphDataset, + ): + self.dataset = dataset + self.cfg = cfg + + def flat_reward_transform(self, y: Tensor) -> FlatRewards: + return FlatRewards(y.float()) + + def sample_conditional_information(self, n: int, train_it: int = 0): + return {"encoding": torch.zeros((n, 1))} + + def cond_info_to_logreward(self, cond_info: Dict[str, Tensor], flat_reward: FlatRewards) -> RewardScalar: + return RewardScalar(flat_reward[:, 0].float()) + + def compute_flat_rewards(self, mols: List[Graph]) -> Tuple[FlatRewards, Tensor]: + if not len(mols): + return FlatRewards(torch.zeros((0, 1))), torch.zeros((0,)).bool() + is_valid = torch.ones(len(mols)).bool() + flat_rewards = torch.tensor([self.dataset.reward(i) for i in mols]).float().reshape((-1, 1)) + return FlatRewards(flat_rewards), is_valid + + def encode_conditional_information(self, info): + encoding = torch.zeros((len(info), 1)) + return {"beta": torch.ones(len(info)), "encoding": encoding.float(), "preferences": info.float()} + + +class UnpermutedGraphEnv(GraphBuildingEnv): + """When using a tabular model, we want to always give the same graph node order, this environment + just makes sure that happens""" + + def set_epc(self, epc): + self.epc = epc + + def step(self, g: Graph, ga: GraphAction): + g = super().step(g, ga) + # get_graph_idx hashes the graph 
and so returns the same idx for the same graph up to isomorphism/node order + return self.epc.states[self.epc.get_graph_idx(g)] + + +class BasicGraphTaskTrainer(GFNTrainer): + cfg: Config + training_data: TwoColorGraphDataset + test_data: TwoColorGraphDataset + + def set_default_hps(self, cfg: Config): + cfg.opt.learning_rate = 1e-4 + cfg.opt.weight_decay = 1e-8 + cfg.opt.momentum = 0.9 + cfg.opt.adam_eps = 1e-8 + cfg.opt.lr_decay = 20000 + cfg.opt.clip_grad_param = 10 + cfg.opt.clip_grad_type = "none" # "norm" + cfg.algo.max_nodes = 7 + cfg.algo.global_batch_size = 64 + cfg.model.num_emb = 96 + cfg.model.num_layers = 8 + cfg.algo.valid_offline_ratio = 0 + cfg.algo.tb.do_correct_idempotent = True # Important to converge to the true p(x) + cfg.algo.tb.do_subtb = True + cfg.algo.tb.do_parameterize_p_b = False + cfg.algo.illegal_action_logreward = -30 # Although, all states are legal here, this shouldn't matter + cfg.num_workers = 8 + cfg.algo.train_random_action_prob = 0.01 + cfg.log_sampled_data = False + # Because we're using a RepeatedPreferencesDataset + cfg.algo.valid_sample_cond_info = False + cfg.algo.offline_ratio = 0 + + def setup(self): + mcfg = self.cfg.task.basic_graph + max_nodes = self.cfg.algo.max_nodes + print(self.cfg.log_dir) + self.rng = np.random.default_rng(142857) + if mcfg.do_tabular_model: + self.env = UnpermutedGraphEnv() + else: + self.env = GraphBuildingEnv() + self._data = load_two_col_data(self.cfg.task.basic_graph.data_root, max_nodes=max_nodes) + self.ctx = BasicGraphContext(max_nodes, num_cond_dim=1, graph_data=self._data, output_gid=True) + self._do_supervised = self.cfg.task.basic_graph.do_supervised + + self.training_data = TwoColorGraphDataset( + self._data, self.ctx, train=True, ratio=mcfg.train_ratio, max_nodes=max_nodes, reward_func=mcfg.reward_func + ) + self.test_data = TwoColorGraphDataset( + self._data, self.ctx, train=False, ratio=mcfg.train_ratio, max_nodes=max_nodes, reward_func=mcfg.reward_func + ) + + self.exact_prob_cb 
= ExactProbCompCallback( + self, + self.training_data.data, + self.device, + cache_root=self.cfg.task.basic_graph.data_root, + cache_path=self.cfg.task.basic_graph.data_root + f"/two_col_epc_cache_{max_nodes}.pkl", + ) + if mcfg.do_tabular_model: + self.env.set_epc(self.exact_prob_cb) + + if self._do_supervised and not self.cfg.task.basic_graph.regress_to_Fsa: + model = GraphTransformerRegressor( + x_dim=self.ctx.num_node_dim, + e_dim=self.ctx.num_edge_dim, + g_dim=1, + num_emb=self.cfg.model.num_emb, + num_layers=self.cfg.model.num_layers, + num_heads=self.cfg.model.graph_transformer.num_heads, + ln_type=self.cfg.model.graph_transformer.ln_type, + ) + elif mcfg.do_tabular_model: + model = TabularHashingModel(self.exact_prob_cb) + if 1: + model.load_state_dict( + torch.load( + "/mnt/ps/home/CORP/emmanuel.bengio/rs/gfn/gflownet/src/gflownet/tasks/logs/basic_graphs/run_6n_9/model_state.pt" + )["models_state_dict"][0] + ) + else: + model = GraphTransformerGFN( + self.ctx, + self.cfg, + do_bck=self.cfg.algo.tb.do_parameterize_p_b, + ) + if not self._do_supervised: + self.test_data = RepeatedPreferenceDataset(np.zeros((32, 1)), 8) + + self.model = self.sampling_model = model + params = [i for i in self.model.parameters()] + if self.cfg.opt.opt == "adam": + self.opt = torch.optim.Adam( + params, + self.cfg.opt.learning_rate, + (self.cfg.opt.momentum, 0.999), + weight_decay=self.cfg.opt.weight_decay, + eps=self.cfg.opt.adam_eps, + ) + elif self.cfg.opt.opt == "SGD": + self.opt = torch.optim.SGD( + params, self.cfg.opt.learning_rate, self.cfg.opt.momentum, weight_decay=self.cfg.opt.weight_decay + ) + self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lambda steps: 2 ** (-steps / self.cfg.opt.lr_decay)) + + algo = self.cfg.algo.method + if algo == "TB" or algo == "subTB": + self.algo = TrajectoryBalance(self.env, self.ctx, self.rng, self.cfg) + self.algo.graph_sampler.sample_temp = 100 + elif algo == "FM": + self.algo = FlowMatching(self.env, self.ctx, self.rng, 
self.cfg) + self.task = BasicGraphTask( + self.cfg, + self.training_data, + ) + self.sampling_tau = self.cfg.algo.sampling_tau + self.mb_size = self.cfg.algo.global_batch_size + self.clip_grad_param = self.cfg.opt.clip_grad_param + self.clip_grad_callback = { + "value": (lambda params: torch.nn.utils.clip_grad_value_(params, self.clip_grad_param)), + "norm": (lambda params: torch.nn.utils.clip_grad_norm_(params, self.clip_grad_param)), + "none": (lambda x: None), + }[self.cfg.opt.clip_grad_type] + + self.algo.task = self.task + if self.cfg.task.basic_graph.test_split_type == "random": + pass + elif self.cfg.task.basic_graph.test_split_type == "bck_traj": + train_idcs, test_idcs = self.exact_prob_cb.get_bck_trajectory_test_split( + self.cfg.task.basic_graph.train_ratio + ) + self.training_data.idcs = train_idcs + self.test_data.idcs = test_idcs + elif self.cfg.task.basic_graph.test_split_type == "subtrees": + train_idcs, test_idcs = self.exact_prob_cb.get_subtree_test_split( + self.cfg.task.basic_graph.train_ratio, self.cfg.task.basic_graph.test_split_seed + ) + self.training_data.idcs = train_idcs + self.test_data.idcs = test_idcs + if not self._do_supervised or self.cfg.task.basic_graph.regress_to_Fsa: + self._callbacks = {"true_px_error": self.exact_prob_cb} + else: + self._callbacks = {} + + def build_callbacks(self): + return self._callbacks + + def step(self, loss: Tensor): + loss.backward() + for i in self.model.parameters(): + self.clip_grad_callback(i) + self.opt.step() + self.opt.zero_grad() + self.lr_sched.step() + if self.sampling_tau > 0: + for a, b in zip(self.model.parameters(), self.sampling_model.parameters()): + b.data.mul_(self.sampling_tau).add_(a.data * (1 - self.sampling_tau)) + + +class RepeatedPreferenceDataset(TwoColorGraphDataset): + def __init__(self, preferences, repeat): + self.prefs = preferences + self.repeat = repeat + + def __len__(self): + return len(self.prefs) * self.repeat + + def __getitem__(self, idx): + assert 0 <= idx < 
len(self) + return torch.tensor(self.prefs[int(idx // self.repeat)]) + + +def hashg(g): + return nx.algorithms.graph_hashing.weisfeiler_lehman_graph_hash(g, node_attr="v") + + +class TabularHashingModel(torch.nn.Module): + """A tabular model to ensure that the objectives converge to the correct solution.""" + + def __init__(self, epc): + super().__init__() + self.epc = epc + self.action_types = [GraphActionType.Stop, GraphActionType.AddNode, GraphActionType.AddEdge] + # This makes a big array which is then sliced and reshaped into logits. + # We're using the masks's shapes to determine the size of the table because they're the same shape + # as the logits. The [1] is the F(s) prediction used for SubTB. + num_total = 0 + self.slices = [] + self.shapes = [] + print("Making table...") + for gid in tqdm(range(len(self.epc.states))): + this_slice = [num_total] + self.shapes.append( + [ + epc._Data[gid].stop_mask.shape, + epc._Data[gid].add_node_mask.shape, + epc._Data[gid].add_edge_mask.shape, + [1], + ] + ) + ns = [np.prod(i) for i in self.shapes[-1]] + this_slice += list(np.cumsum(ns) + num_total) + num_total += sum(ns) + self.slices.append(this_slice) + self.table = nn.Parameter(torch.zeros((num_total,))) + # For TB we have to have a unique parameter for logZ + self._logZ = nn.Parameter(torch.zeros((1,))) + print("Made table of size", num_total) + + def __call__(self, g: gd.Batch, cond_info): + """This ignores cond_info, which we don't use anyways for now, but beware""" + ns = [self.slices[i] for i in g.gid.cpu()] + shapes = [self.shapes[i] for i in g.gid.cpu()] + items = [[self.table[a:b].reshape(s) for a, b, s in zip(n, n[1:], ss)] for n, ss in zip(ns, shapes)] + logits = zip(*[i[0:3] for i in items]) + logF_s = torch.stack([i[-1] for i in items]) + masks = [GraphTransformerGFN._action_type_to_mask(None, t, g) for t in self.action_types] + return ( + GraphActionCategorical( + g, + logits=[torch.cat(i, 0) * m - 1000 * (1 - m) for i, m in zip(logits, masks)], + 
def load_cache(self, cache_path):
    """Load a cache previously written by ``save_cache`` and move it to ``self.dev``.

    Populates ``mdp_graph``, ``_Data``, ``_hash_to_graphs``,
    ``precomputed_batches`` and ``precomputed_indices`` from the file.

    Args:
        cache_path: path of the file produced by ``save_cache``.
    """
    print("Loading cache @", cache_path)
    # NOTE: torch.load unpickles the file's content; only load cache files that
    # come from a trusted source (e.g. ones produced by `save_cache`).
    # The context manager closes the handle even if deserialisation fails.
    with open(cache_path, "rb") as f:
        cache = torch.load(f)
    self.mdp_graph = cache["mdp"]
    self._Data = cache["Data"]
    self._hash_to_graphs = cache["hashmap"]
    bs, ids = cache["batches"], cache["idces"]
    print("Done")
    # Batches and the per-graph (terminating, non-terminating) index pairs are stored on CPU.
    self.precomputed_batches = [batch.to(self.dev) for batch in bs]
    self.precomputed_indices = [
        [(pair[0].to(self.dev), pair[1].to(self.dev)) for pair in per_batch] for per_batch in ids
    ]
def compute_cache(self, tqdm_disable=None):
    """Enumerate every legal action of every state and build the MDP graph.

    For each minibatch of ``self.states`` this stores the collated batch in
    ``self.precomputed_batches`` and, per graph, a pair of index tensors
    ``(end_indices, being_indices)`` in ``self.precomputed_indices``. Each
    column of those tensors is ``(state index, successor index, flat action
    index)``, where the flat action index addresses the concatenation of the
    flattened action masks in ``self.ctx.action_type_order`` order.
    Transitions are also recorded as edges of ``self.mdp_graph``.

    Args:
        tqdm_disable: forwarded to ``tqdm(disable=...)``.
    """
    states, mbs, dev = self.states, self.mbs, self.dev
    mdp_graph = nx.MultiDiGraph()
    self.precomputed_batches = []
    self.precomputed_indices = []
    self._hash_to_graphs = {}
    states_hash = [hashg(i) for i in tqdm(states, disable=tqdm_disable)]
    self._Data = states_Data = gd.Batch.from_data_list(
        [self.ctx.graph_to_Data(i) for i in tqdm(states, disable=tqdm_disable)]
    )
    # Bucket state indices by hash; get_graph_idx resolves collisions by isomorphism.
    for i, h in enumerate(states_hash):
        self._hash_to_graphs.setdefault(h, []).append(i)

    for bi in tqdm(range(0, len(states), mbs), disable=tqdm_disable):
        bs = states[bi : bi + mbs]
        bD = states_Data[bi : bi + mbs]
        indices = list(range(bi, bi + len(bs)))
        # TODO: if the environment's masks are well designed, filtering out terminal states
        # before collating shouldn't be necessary (it is not done here).
        batch = self.ctx.collate(bD).to(dev)
        self.precomputed_batches.append(batch)

        actions = [[] for _ in range(len(bs))]
        offset = 0
        # Use the callback's own context: the bare name `ctx` only exists as a module
        # global when this file is run as a script.
        for u, atype in enumerate(self.ctx.action_type_order):
            # /!\ This assumes mask.shape == cat.logit[i].shape
            mask = getattr(batch, atype.mask_name)
            batch_key = GraphTransformerGFN._graph_part_to_key[
                GraphTransformerGFN._action_type_to_graph_part[atype]
            ]
            batch_idx = (
                getattr(batch, f"{batch_key}_batch" if batch_key != "x" else "batch")
                if batch_key is not None
                else torch.arange(batch.num_graphs, device=dev)
            )
            mslice = (
                batch._slice_dict[batch_key]
                if batch_key is not None
                else torch.arange(batch.num_graphs + 1, device=dev)
            )
            for j in mask.nonzero().cpu().numpy():
                # mask.nonzero() yields only the legal (row, col) positions. `k` is computed
                # from the full mask shape, so it is the position of this action in the
                # concatenation of all flattened masks, not among the legal actions only.
                k = j[0] * mask.shape[1] + j[1] + offset
                jb = batch_idx[j[0]].item()
                # Store the row relative to the first row of graph `jb` within the batch.
                actions[jb].append((u, j[0] - mslice[jb].item(), j[1], k))
            offset += mask.numel()

        all_indices = []
        for jb, j_acts in enumerate(actions):
            end_indices = []
            being_indices = []
            for *a, srcidx in j_acts:
                idx = indices[jb]
                # Action type index 0 does not transition: the successor is the state itself.
                sp = self.env.step(bs[jb], self.ctx.aidx_to_GraphAction(bD[jb], a[:3])) if a[0] != 0 else bs[jb]
                # Unknown successors map to the extra index len(states).
                spidx = self.get_graph_idx(sp, len(states))
                if a[0] == 0 or spidx >= len(states):
                    end_indices.append((idx, spidx, srcidx))
                    mdp_graph.add_edge(idx, spidx, srci=srcidx, term=True, a=a)
                else:
                    being_indices.append((idx, spidx, srcidx))
                    mdp_graph.add_edge(idx, spidx, srci=srcidx, term=False, a=a)
            all_indices.append((torch.tensor(end_indices).T.to(dev), torch.tensor(being_indices).T.to(dev)))
        self.precomputed_indices.append(all_indices)
    self.mdp_graph = mdp_graph
topological sort. + # Wrong results otherwise. + for bi, batch, pre_indices in zip( + tqdm(range(0, len(self.states), self.mbs), disable=tqdm_disable), + self.precomputed_batches, + self.precomputed_indices, + ): + bs = self.states[bi : bi + self.mbs] + # This isn't even right: + # indices = list(range(bi, bi + len(bs))) + # non_terminals = [(i, j) for i, j in zip(bs, indices) if not self.is_terminal(i)] + # if not len(non_terminals): + # continue + # bs, indices = zip(*non_terminals) + with torch.no_grad(): + cat, *_, mo = model(batch, cond_info[: len(bs)]) + logprobs = torch.cat([i.flatten() for i in cat.logsoftmax()]) + for end_indices, being_indices in pre_indices: + if being_indices.shape[0] > 0: + s_idces, sp_idces, a_idces = being_indices + src = prob_of_being_t[s_idces] + logprobs[a_idces] + inter = scatter_logsumexp(src, sp_idces, dim_size=prob_of_being_t.shape[-1]) + prob_of_being_t = torch.logaddexp(inter, prob_of_being_t) + # prob_of_being_t = scatter_add( + # (prob_of_being_t[s_idces] + logprobs[a_idces]).exp(), sp_idces, out=prob_of_being_t.exp() + # ).log() + if end_indices.shape[0] > 0: + s_idces, sp_idces, a_idces = end_indices + src = prob_of_being_t[s_idces] + logprobs[a_idces] + inter = scatter_logsumexp(src, sp_idces, dim_size=prob_of_ending_t.shape[-1]) + prob_of_ending_t = torch.logaddexp(inter, prob_of_ending_t) + # prob_of_ending_t = scatter_add( + # (prob_of_being_t[s_idces] + logprobs[a_idces]).exp(), sp_idces, out=prob_of_ending_t.exp() + # ).log() + return prob_of_ending_t + + def recompute_flow(self, tqdm_disable=None): + g = self.mdp_graph + for i in g: + g.nodes[i]["F"] = -100 + for i in tqdm(list(range(len(g)))[::-1], disable=tqdm_disable): + p = sorted(list(g.predecessors(i)), reverse=True) + num_back = len([n for n in p if n != i]) + for j in p: + if j == i: + g.nodes[j]["F"] = np.logaddexp(g.nodes[j]["F"], self.log_rewards[j]) + g.edges[(i, i, 0)]["F"] = self.log_rewards[j].item() + else: + backflow = 
def get_subtree_test_split(self, r, seed=142857):
    """Split the state indices into train/test sets made of whole forward subtrees.

    Start states are drawn at random among the 6-node states with at least
    ``edge_limit`` edges (the limit starts at 11 and is lowered each time the
    candidates are refreshed). Every state reachable from a start state through
    non-Stop forward actions is added to the test set, until it holds at least
    ``int((1 - r) * len(self.states))`` states. When ``self.cache_root`` is set
    the split is cached on disk, keyed by ``r`` and ``seed``.

    Note that this reseeds numpy's global random state with ``seed``.

    Args:
        r: train ratio; ``1 - r`` is the requested test fraction.
        seed: seed passed to ``np.random.seed``.

    Returns:
        ``(train_set, test_set)`` as two lists of state indices, whether or not
        the result came from the cache.

    Raises:
        ValueError: if there is no candidate start state, or if every candidate
            is already in the test set and the limit cannot be relaxed further.
    """
    cache_path = None
    if self.cache_root is not None:
        cache_path = f"{self.cache_root}/subtree_split_{r}_{seed}.pkl"
        if os.path.exists(cache_path):
            # NOTE: pickle executes code on load; only point cache_root at trusted files.
            with open(cache_path, "rb") as f:
                cached_train, cached_test = pickle.load(f)
            # The cache stores numpy arrays; hand back lists like the non-cached path does.
            return np.asarray(cached_train).tolist(), np.asarray(cached_test).tolist()
    test_set = set()
    n = int((1 - r) * len(self.states))
    np.random.seed(seed)
    start_states_idx, available_start_states, start_states = [], [], []
    edge_limit = 11
    while len(test_set) < n:
        num_ss = len([i for i in start_states_idx if i not in test_set])
        if num_ss == 0 or len(available_start_states) == 0:
            candidates = [
                (s0, i) for i, s0 in enumerate(self.states) if len(s0.nodes) == 6 and len(s0.edges) >= edge_limit
            ]
            if not candidates:
                raise ValueError(f"No 6-node start state with at least {edge_limit} edges is available")
            start_states, start_states_idx = zip(*candidates)
            if edge_limit <= 0 and all(i in test_set for i in start_states_idx):
                # Every 6-node state is already a candidate and already in the test set:
                # nothing can be added anymore, so bail out instead of looping forever.
                raise ValueError(
                    f"Cannot build a test set of {n} states: only {len(test_set)} states are reachable"
                )
            available_start_states = list(range(len(start_states)))
            edge_limit -= 1
        assi = np.random.randint(len(available_start_states))
        ssi = available_start_states.pop(assi)
        s0 = start_states[ssi]
        i0 = self.get_graph_idx(s0)
        if i0 in test_set:
            continue
        # Depth-first traversal of everything reachable from s0.
        stack = [(s0, i0)]
        while len(stack):
            s, i = stack.pop()
            if i in test_set:
                continue
            test_set.add(i)
            actions = [
                (u, a.item(), b.item())
                for u, ra in enumerate(self.ctx.action_type_order)
                if ra != GraphActionType.Stop
                for a, b in getattr(self._Data[i], ra.mask_name).nonzero()
            ]
            for action in actions:
                gaction = self.ctx.aidx_to_GraphAction(self._Data[i], action, fwd=True)
                sp = self.env.step(s, gaction)
                ip = self.get_graph_idx(sp)  # This finds the graph index taking into account isomorphism
                if ip in test_set:
                    continue
                sp = self.states[ip]  # We still have to get the original graph so that the Data instance is correct
                stack.append((sp, ip))
    train_set = list(set(range(len(self.states))).difference(test_set))
    test_set = list(test_set)
    np.random.shuffle(train_set)
    if cache_path is not None:
        # On-disk format is unchanged (a tuple of two arrays) so existing caches stay valid.
        with open(cache_path, "wb") as f:
            pickle.dump((np.array(train_set), np.array(test_set)), f)
    return train_set, test_set
class BGSupervisedTrainer(BasicGraphTaskTrainer):
    """Trainer that fits the model by supervised regression (``Regression``) on the datasets."""

    def setup(self):
        """Run the parent setup, then switch the algorithm and datasets to supervised mode."""
        super().setup()
        task_cfg = self.cfg.task.basic_graph
        self.algo = Regression()
        self.algo.loss_type = task_cfg.supervised_loss
        self.algo.regress_to_Fsa = task_cfg.regress_to_Fsa
        self.training_data.output_graphs = True
        self.test_data.output_graphs = True
        if task_cfg.regress_to_P_F:
            # P_F is just the normalized Fsa, so this flag must be on
            assert task_cfg.regress_to_Fsa

        # NOTE(review): all three dataset flags below are driven by `regress_to_Fsa`;
        # confirm whether `regress_to_F` / `regress_to_P_F` were meant to be used instead.
        for dataset in (self.training_data, self.test_data):
            dataset.compute_Fsa = task_cfg.regress_to_Fsa
            dataset.regress_to_F = task_cfg.regress_to_Fsa
            dataset.compute_normalized_Fsa = task_cfg.regress_to_Fsa
            dataset.epc = self.exact_prob_cb

    def _make_loader(self, dataset, **extra) -> DataLoader:
        """Build a DataLoader over ``dataset`` with the trainer's batch size and worker settings."""
        n_workers = self.cfg.num_workers
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.mb_size,
            num_workers=n_workers,
            persistent_workers=n_workers > 0,
            collate_fn=dataset.collate_fn,
            **extra,
        )

    def build_training_data_loader(self) -> DataLoader:
        """Return a shuffling loader over the training dataset."""
        return self._make_loader(self.training_data, shuffle=True)

    def build_validation_data_loader(self) -> DataLoader:
        """Return a loader over the test dataset (no shuffling)."""
        return self._make_loader(self.test_data)
@dataclass
class BasicGraphConfig:
    """Options of the basic graph task, exposed as ``cfg.task.basic_graph``."""

    # NOTE(review): presumably toggles saving of generated graphs -- confirm against the trainer
    do_save_generated: bool = True
    # NOTE(review): presumably the directory holding the task's data files -- confirm
    data_root: str = "./data/basic_graph_task"
    reward_func: str = "count"  # One of cliques, even_neighbors, count, const
    # When True, the example `main()` launches BGSupervisedTrainer instead of BasicGraphTaskTrainer
    do_supervised: bool = False
    # When True, the trainer builds a TabularHashingModel as its model
    do_tabular_model: bool = False
    # Loss used by the supervised Regression algorithm: "MSE" or "MAE"
    supervised_loss: str = "MSE"
    # Train fraction handed to the test-split functions; the test set gets int((1 - ratio) * n_states) states
    train_ratio: float = 0.9
    i2h_width: int = 4  # This is a model hyperparameter that I'm testing out here, should move to model config
    # Distillation:
    regress_to_F: bool = False
    # Read by BGSupervisedTrainer to configure the Regression algorithm and the datasets
    regress_to_Fsa: bool = False
    # Requires regress_to_Fsa (asserted in BGSupervisedTrainer.setup)
    regress_to_P_F: bool = False
    # Test split
    # The trainer's setup checks for "random", "bck_traj" and "subtrees"
    test_split_type: str = "subtrees"
    # Seed passed to the "subtrees" split
    test_split_seed: int = 142857
BasicGraphConfig() diff --git a/src/gflownet/tasks/seh_frag.py b/src/gflownet/tasks/seh_frag.py index 4d0cc624..66afc8c0 100644 --- a/src/gflownet/tasks/seh_frag.py +++ b/src/gflownet/tasks/seh_frag.py @@ -125,7 +125,7 @@ def main(): """Example of how this model can be run outside of Determined""" hps = { "log_dir": "./logs/debug_run_seh_frag", - "device": torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"), + "device": "cuda" if torch.cuda.is_available() else "cpu", "overwrite_existing_exp": True, "num_training_steps": 10_000, "num_workers": 8, diff --git a/src/gflownet/tasks/seh_frag_moo.py b/src/gflownet/tasks/seh_frag_moo.py index 8ad73320..41fb8210 100644 --- a/src/gflownet/tasks/seh_frag_moo.py +++ b/src/gflownet/tasks/seh_frag_moo.py @@ -358,7 +358,7 @@ def main(): """Example of how this model can be run.""" hps = { "log_dir": "./logs/debug_run_sfm", - "device": torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"), + "device": "cuda" if torch.cuda.is_available() else "cpu", "pickle_mp_messages": True, "overwrite_existing_exp": True, "seed": 0, diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 93e0e0a5..254d5499 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -107,12 +107,13 @@ def __init__(self, hps: Dict[str, Any]): # `sampling_model` is used by the data workers to sample new objects from the model. Can be # the same as `model`. 
self.sampling_model: nn.Module - self.replay_buffer: Optional[ReplayBuffer] + self.replay_buffer: Optional[ReplayBuffer] = None self.mb_size: int self.env: GraphBuildingEnv self.ctx: GraphBuildingEnvContext self.task: GFNTask self.algo: GFNAlgorithm + self.cfg: Config # There are three sources of config values # - The default values specified in individual config classes @@ -187,7 +188,11 @@ def build_callbacks(self): def build_training_data_loader(self) -> DataLoader: model, dev = self._wrap_for_mp(self.sampling_model, send_to_device=True) - replay_buffer, _ = self._wrap_for_mp(self.replay_buffer, send_to_device=False) + replay_buffer, _ = ( + self._wrap_for_mp(self.replay_buffer, send_to_device=False) + if self.replay_buffer is not None + else (None, None) + ) iterator = SamplingIterator( self.training_data, model, From eb7c23437d57c761402d68827080d5da55bb204f Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Tue, 8 Aug 2023 10:35:06 -0400 Subject: [PATCH 2/7] tox --- src/gflownet/algo/trajectory_balance.py | 2 +- src/gflownet/data/sampling_iterator.py | 1 - src/gflownet/tasks/basic_graph_task.py | 4 ++-- src/gflownet/trainer.py | 1 - 4 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 21122d2a..064f9610 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -6,7 +6,7 @@ import torch.nn as nn import torch_geometric.data as gd from torch import Tensor -from torch_scatter import scatter, scatter_sum, scatter_logsumexp +from torch_scatter import scatter, scatter_logsumexp, scatter_sum from gflownet.algo.graph_sampling import GraphSampler from gflownet.config import Config diff --git a/src/gflownet/data/sampling_iterator.py b/src/gflownet/data/sampling_iterator.py index 3163c8fd..6d503e63 100644 --- a/src/gflownet/data/sampling_iterator.py +++ b/src/gflownet/data/sampling_iterator.py @@ -12,7 +12,6 @@ from rdkit import 
Chem, RDLogger from torch.utils.data import Dataset, IterableDataset -from gflownet.config import Config from gflownet.data.replay_buffer import ReplayBuffer from gflownet.envs.graph_building_env import GraphActionCategorical diff --git a/src/gflownet/tasks/basic_graph_task.py b/src/gflownet/tasks/basic_graph_task.py index a79b43e1..bf6801fc 100644 --- a/src/gflownet/tasks/basic_graph_task.py +++ b/src/gflownet/tasks/basic_graph_task.py @@ -1,5 +1,5 @@ -import os import bz2 +import os import pickle from typing import Dict, List, Tuple @@ -21,9 +21,9 @@ from gflownet.envs.graph_building_env import ( Graph, GraphAction, + GraphActionCategorical, GraphActionType, GraphBuildingEnv, - GraphActionCategorical, ) from gflownet.models.graph_transformer import GraphTransformer, GraphTransformerGFN from gflownet.trainer import FlatRewards, GFNAlgorithm, GFNTask, GFNTrainer, RewardScalar diff --git a/src/gflownet/trainer.py b/src/gflownet/trainer.py index 254d5499..197e197b 100644 --- a/src/gflownet/trainer.py +++ b/src/gflownet/trainer.py @@ -113,7 +113,6 @@ def __init__(self, hps: Dict[str, Any]): self.ctx: GraphBuildingEnvContext self.task: GFNTask self.algo: GFNAlgorithm - self.cfg: Config # There are three sources of config values # - The default values specified in individual config classes From 424ad7af22ea4a1354a39fea23d19b16f57bf7e0 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Tue, 8 Aug 2023 10:39:01 -0400 Subject: [PATCH 3/7] env tests --- tests/test_envs.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/test_envs.py b/tests/test_envs.py index 204a17cb..85c3eb38 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -10,6 +10,7 @@ from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext from gflownet.envs.graph_building_env import GraphBuildingEnv from gflownet.envs.mol_building_env import MolBuildingEnvContext +from gflownet.envs.basic_graph_ctx import BasicGraphContext from gflownet.models import 
@pytest.fixture
def two_node_states_basic(request):
    """Provide the two-node states of the basic graph context, memoised in pytest's cache."""
    cache_key = "basic_graph_env/two_node_states"
    encoded = request.config.cache.get(cache_key, None)
    if encoded is not None:
        # Cache hit: undo the base64 + pickle encoding used when storing.
        return pickle.loads(base64.b64decode(encoded))
    states = build_two_node_states(get_basic_env_ctx())
    # pytest caches through JSON so we have to make a clean enough string
    payload = base64.b64encode(pickle.dumps(states)).decode()
    request.config.cache.set(cache_key, payload)
    return states
It treats GraphBuildingEnv.count_backward_transitions as the ground truth and raises an error if there is @@ -176,3 +193,11 @@ def test_backwards_mask_equivalence_atom(two_node_states_atoms): def test_backwards_mask_equivalence_ipa_atom(two_node_states_atoms): _test_backwards_mask_equivalence_ipa(two_node_states_atoms, get_atom_env_ctx()) + + +def test_backwards_mask_equivalence_basic(two_node_states_basic): + _test_backwards_mask_equivalence(two_node_states_basic, get_basic_env_ctx()) + + +def test_backwards_mask_equivalence_ipa_basic(two_node_states_basic): + _test_backwards_mask_equivalence_ipa(two_node_states_basic, get_basic_env_ctx()) From 204035614104285a117ac2e81b814c12ff4a1593 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Thu, 10 Aug 2023 11:01:52 -0400 Subject: [PATCH 4/7] fixed tabular node ordering problem --- src/gflownet/algo/trajectory_balance.py | 10 +++++-- src/gflownet/data/sampling_iterator.py | 1 + src/gflownet/envs/basic_graph_ctx.py | 37 ++++++++++++++++++++++- src/gflownet/tasks/basic_graph_task.py | 40 +++++++++++++++++++++---- 4 files changed, 79 insertions(+), 9 deletions(-) diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 064f9610..ad82c37c 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -88,7 +88,7 @@ def __init__( ) if self.cfg.do_subtb: self._subtb_max_len = self.global_cfg.algo.max_len + 2 - self._init_subtb(torch.device("cuda")) # TODO: where are we getting device info? + self._init_subtb(torch.device(cfg.device)) # TODO: where are we getting device info? def create_training_data_from_own_samples( self, model: TrajectoryBalanceModel, n: int, cond_info: Tensor, random_action_prob: float @@ -138,7 +138,11 @@ def create_training_data_from_graphs(self, graphs): trajs: List[Dict{'traj': List[tuple[Graph, GraphAction]]}] A list of trajectories. 
""" - trajs = [{"traj": generate_forward_trajectory(i)} for i in graphs] + if hasattr(self.ctx, "relabel"): + relabel = self.ctx.relabel + else: + relabel = lambda *x: x + trajs = [{"traj": [relabel(*t) for t in generate_forward_trajectory(i)]} for i in graphs] for traj in trajs: n_back = [ self.env.count_backward_transitions(gp, check_idempotent=self.cfg.do_correct_idempotent) @@ -505,6 +509,6 @@ def subtb_loss_fast(self, P_F, P_B, F, R, traj_lengths): P_F_sums = scatter_sum(P_F[idces + offset], dests) P_B_sums = scatter_sum(P_B[idces + offset], dests) F_start = F[offset : offset + T].repeat_interleave(T - ar[:T]) - F_end = F_and_R[fidces] + F_end = F_and_R[fidces] # .detach() total_loss[ep] = (F_start - F_end + P_F_sums - P_B_sums).pow(2).sum() / car[T] return total_loss diff --git a/src/gflownet/data/sampling_iterator.py b/src/gflownet/data/sampling_iterator.py index 6d503e63..358d0545 100644 --- a/src/gflownet/data/sampling_iterator.py +++ b/src/gflownet/data/sampling_iterator.py @@ -334,6 +334,7 @@ def iterator(self): batch.preferences = cond_info.get("preferences", None) batch.focus_dir = cond_info.get("focus_dir", None) batch.extra_info = extra_info + batch.trajs = trajs # TODO: we could very well just pass the cond_info dict to construct_batch above, # and the algo can decide what it wants to put in the batch object diff --git a/src/gflownet/envs/basic_graph_ctx.py b/src/gflownet/envs/basic_graph_ctx.py index fd04e5e0..c862a394 100644 --- a/src/gflownet/envs/basic_graph_ctx.py +++ b/src/gflownet/envs/basic_graph_ctx.py @@ -1,3 +1,4 @@ +import copy from typing import Dict, List, Tuple import networkx as nx @@ -9,6 +10,7 @@ Graph, GraphAction, GraphActionType, + GraphBuildingEnv, GraphBuildingEnvContext, graph_without_edge, ) @@ -63,6 +65,7 @@ def __init__(self, max_nodes=7, num_cond_dim=0, graph_data=None, output_gid=Fals GraphActionType.RemoveNode, GraphActionType.RemoveEdge, ] + self._env = GraphBuildingEnv() self.device = torch.device("cpu") 
def relabel(self, g, ga):
    """Re-express ``(g, ga)`` in the node ordering of the stored copy of ``g``.

    Looks up the graph of ``self.graph_data`` that ``get_graph_idx`` associates
    with ``g``, maps the action's ``source``/``target`` through a node mapping
    from ``g`` to that stored graph, and checks that applying the relabelled
    action to the stored graph gives a graph isomorphic to ``step(g, ga)``.

    Args:
        g: the graph the action applies to.
        ga: the action to relabel; it is copied, not modified.

    Returns:
        A deep copy of the stored graph and the relabelled action.

    Raises:
        ValueError: if no node mapping is found although the action refers to
            nodes, or if the relabelled action leads to a different graph.
    """
    is_stop = ga.action == GraphActionType.Stop
    if not is_stop:
        gp = self._env.step(g, ga)
    ig = self.graph_data[self.get_graph_idx(g)]
    rmap = nx.vf2pp_isomorphism(g, ig, "v")
    ga = copy.copy(ga)
    if rmap is None:
        if not len(g):
            # No mapping is returned for the empty graph; node 0 maps to itself.
            rmap = {0: 0}
        elif ga.source is not None or ga.target is not None:
            # Fail with an explicit message rather than a TypeError on `None[...]` below.
            raise ValueError(f"No isomorphism found between {g} and its stored copy {ig}")
    if ga.source is not None:
        ga.source = rmap[ga.source]
    if ga.target is not None:
        ga.target = rmap[ga.target]
    if not is_stop:
        gp2 = self._env.step(ig, ga)
        if not nx.is_isomorphic(gp2, gp, lambda a, b: a == b):
            raise ValueError(f"Relabelled action {ga} does not lead to the same graph as the original action")
    return copy.deepcopy(ig), ga
def set_values(self, epc):
    """Set the values of the table to the true values of the MDP. This tabular model should have 0 error.

    For every state ``i`` of ``epc.states``, the flow ``F`` stored on each
    outgoing edge of ``epc.mdp_graph`` is written to the table entry of the
    edge's action ``a = (type, row, col)``, and the node's flow is written to
    the state's fourth slot (the ``F(s)`` prediction). ``_logZ`` is set to the
    flow of node 0.

    Args:
        epc: object exposing ``states`` and an ``mdp_graph`` whose nodes and
            edges carry an ``"F"`` attribute (and edges an ``"a"`` attribute).
    """
    mdp = epc.mdp_graph
    for i in tqdm(range(len(epc.states))):
        offsets, shapes = self.slices[i], self.shapes[i]
        for neighbor in mdp.neighbors(i):
            for edge in mdp.get_edge_data(i, neighbor).values():
                a, F = edge["a"], edge["F"]
                # Row-major position of (row, col) inside this state's block for action type a[0].
                self.table.data[offsets[a[0]] + a[1] * shapes[a[0]][1] + a[2]] = F
        self.table.data[offsets[3]] = mdp.nodes[i]["F"]
    # Fill in place so `_logZ` keeps the shape (1,), dtype and device it was created with;
    # rebinding `.data` to a fresh 0-dim tensor would change the parameter's shape.
    self._logZ.data.fill_(float(mdp.nodes[0]["F"]))
-- not + "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2}, + # "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2, "momentum": 0.0, "lr_decay": 1e10, "opt": "RMSProp"}, + # "opt": {"opt": "SGD", "learning_rate": 1e-2, "momentum": 0}, + "algo": {"global_batch_size": 512, "tb": {"do_subtb": True}, "max_nodes": 6, "offline_ratio": 1 / 4}, + "task": {"basic_graph": {"do_supervised": False, "do_tabular_model": True, "train_ratio": 1}}, # + } if hps["task"]["basic_graph"]["do_supervised"]: trial = BGSupervisedTrainer(hps, torch.device("cuda")) else: From b37b01e2bf76e97ed6ad56db5d04b3c60786a94f Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Wed, 16 Aug 2023 11:20:07 -0400 Subject: [PATCH 5/7] fix broken merge --- src/gflownet/envs/basic_graph_ctx.py | 3 ++- src/gflownet/tasks/basic_graph_task.py | 23 +++++++++++++++++++++-- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/src/gflownet/envs/basic_graph_ctx.py b/src/gflownet/envs/basic_graph_ctx.py index c862a394..362efb42 100644 --- a/src/gflownet/envs/basic_graph_ctx.py +++ b/src/gflownet/envs/basic_graph_ctx.py @@ -43,6 +43,7 @@ def __init__(self, max_nodes=7, num_cond_dim=0, graph_data=None, output_gid=Fals "v": [0, 1], # Imagine this is as colors } self._num_rw_feat = 8 + self.not_a_molecule_env = True self.num_new_node_values = len(self.node_attr_values["v"]) self.num_node_attr_logits = None @@ -160,7 +161,7 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int def graph_to_Data(self, g: Graph) -> gd.Data: """Convert a networkx Graph to a torch geometric Data instance""" - if self.graph_data is not None: + if self.graph_data is not None and False: # This caching achieves two things, first we'll speed things up gidx = self.get_graph_idx(g) if gidx in self._cache: diff --git a/src/gflownet/tasks/basic_graph_task.py b/src/gflownet/tasks/basic_graph_task.py index 4f961f69..849162f4 100644 --- a/src/gflownet/tasks/basic_graph_task.py +++ 
b/src/gflownet/tasks/basic_graph_task.py @@ -915,10 +915,29 @@ def main(): "algo": {"global_batch_size": 512, "tb": {"do_subtb": True}, "max_nodes": 6, "offline_ratio": 1 / 4}, "task": {"basic_graph": {"do_supervised": False, "do_tabular_model": True, "train_ratio": 1}}, # } + + hps = { + "num_training_steps": 20000, + "validate_every": 100, + "num_workers": 0, + "log_dir": "./logs/basic_graphs/run_6n_pb2", + "model": {"num_layers": 2, "num_emb": 256}, + # WARNING: SubTB is detached targets! -- not + "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2}, + # "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2, "momentum": 0.0, "lr_decay": 1e10, "opt": "RMSProp"}, + # "opt": {"opt": "SGD", "learning_rate": 1e-2, "momentum": 0}, + "algo": { + "global_batch_size": 512, + "tb": {"do_subtb": True, "do_parameterize_p_b": False}, + "max_nodes": 6, + "offline_ratio": 0 / 4, + }, + "task": {"basic_graph": {"do_supervised": False, "do_tabular_model": False, "train_ratio": 1}}, # + } if hps["task"]["basic_graph"]["do_supervised"]: - trial = BGSupervisedTrainer(hps, torch.device("cuda")) + trial = BGSupervisedTrainer(hps) else: - trial = BasicGraphTaskTrainer(hps, torch.device("cuda")) + trial = BasicGraphTaskTrainer(hps) torch.set_num_threads(1) trial.verbose = True trial.print_every = 1 From ba77d4a4a5f1c2604dfd44a800e58c2171eae842 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Wed, 6 Sep 2023 12:47:51 -0400 Subject: [PATCH 6/7] minor fixes --- src/gflownet/algo/trajectory_balance.py | 2 +- src/gflownet/envs/basic_graph_ctx.py | 7 +-- src/gflownet/tasks/basic_graph_task.py | 63 ++++--------------------- 3 files changed, 13 insertions(+), 59 deletions(-) diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 2035081a..593c7b3b 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -577,7 +577,7 @@ def subtb_loss_fast(self, P_F, P_B, F, R, traj_lengths): P_F_sums = scatter_sum(P_F[idces + 
offset], dests) P_B_sums = scatter_sum(P_B[idces + offset], dests) F_start = F[offset : offset + T].repeat_interleave(T - ar[:T]) - F_end = F_and_R[fidces] # .detach() + F_end = F_and_R[fidces] total_loss[ep] = (F_start - F_end + P_F_sums - P_B_sums).pow(2).sum() / car[T] return total_loss diff --git a/src/gflownet/envs/basic_graph_ctx.py b/src/gflownet/envs/basic_graph_ctx.py index 362efb42..1fce3801 100644 --- a/src/gflownet/envs/basic_graph_ctx.py +++ b/src/gflownet/envs/basic_graph_ctx.py @@ -38,6 +38,7 @@ class BasicGraphContext(GraphBuildingEnvContext): def __init__(self, max_nodes=7, num_cond_dim=0, graph_data=None, output_gid=False): self.max_nodes = max_nodes self.output_gid = output_gid + self.use_graph_cache = False self.node_attr_values = { "v": [0, 1], # Imagine this is as colors @@ -159,9 +160,9 @@ def GraphAction_to_aidx(self, g: gd.Data, action: GraphAction) -> Tuple[int, int type_idx = self.bck_action_type_order.index(action.action) return (type_idx, int(row), int(col)) - def graph_to_Data(self, g: Graph) -> gd.Data: + def graph_to_Data(self, g: Graph, t: int = 0) -> gd.Data: """Convert a networkx Graph to a torch geometric Data instance""" - if self.graph_data is not None and False: + if self.graph_data is not None and self.use_graph_cache: # This caching achieves two things, first we'll speed things up gidx = self.get_graph_idx(g) if gidx in self._cache: @@ -207,7 +208,7 @@ def graph_to_Data(self, g: Graph) -> gd.Data: gid=gid, ) ) - if self.graph_data is not None: + if self.graph_data is not None and self.use_graph_cache: self._cache[gidx] = data return data diff --git a/src/gflownet/tasks/basic_graph_task.py b/src/gflownet/tasks/basic_graph_task.py index 849162f4..36df986e 100644 --- a/src/gflownet/tasks/basic_graph_task.py +++ b/src/gflownet/tasks/basic_graph_task.py @@ -14,6 +14,7 @@ from torch_scatter import scatter_logsumexp from tqdm import tqdm +from gflownet.algo.config import TBVariant from gflownet.algo.flow_matching import 
FlowMatching from gflownet.algo.trajectory_balance import TrajectoryBalance from gflownet.config import Config @@ -272,7 +273,7 @@ def set_default_hps(self, cfg: Config): cfg.model.num_layers = 8 cfg.algo.valid_offline_ratio = 0 cfg.algo.tb.do_correct_idempotent = True # Important to converge to the true p(x) - cfg.algo.tb.do_subtb = True + cfg.algo.tb.variant = TBVariant.SubTB1 cfg.algo.tb.do_parameterize_p_b = False cfg.algo.illegal_action_logreward = -30 # Although, all states are legal here, this shouldn't matter cfg.num_workers = 8 @@ -293,6 +294,7 @@ def setup(self): self.env = GraphBuildingEnv() self._data = load_two_col_data(self.cfg.task.basic_graph.data_root, max_nodes=max_nodes) self.ctx = BasicGraphContext(max_nodes, num_cond_dim=1, graph_data=self._data, output_gid=True) + self.ctx.use_graph_cache = mcfg.do_tabular_model self._do_supervised = self.cfg.task.basic_graph.do_supervised self.training_data = TwoColorGraphDataset( @@ -326,12 +328,6 @@ def setup(self): model = TabularHashingModel(self.exact_prob_cb) if 0: model.set_values(self.exact_prob_cb) - if 0: # reload_bit - model.load_state_dict( - torch.load( - "/mnt/ps/home/CORP/emmanuel.bengio/rs/gfn/gflownet/src/gflownet/tasks/logs/basic_graphs/run_6n_4/model_state.pt" - )["models_state_dict"][0] - ) else: model = GraphTransformerGFN( self.ctx, @@ -362,7 +358,6 @@ def setup(self): algo = self.cfg.algo.method if algo == "TB" or algo == "subTB": self.algo = TrajectoryBalance(self.env, self.ctx, self.rng, self.cfg) - self.algo.graph_sampler.sample_temp = 100 elif algo == "FM": self.algo = FlowMatching(self.env, self.ctx, self.rng, self.cfg) self.task = BasicGraphTask( @@ -611,18 +606,8 @@ def compute_cache(self, tqdm_disable=None): bs = states[bi : bi + mbs] bD = states_Data[bi : bi + mbs] indices = list(range(bi, bi + len(bs))) - # TODO: if the environment's masks are well designed, this non_terminal business shouldn't be necessary - # non_terminals = [(i, j, k) for i, j, k in zip(bs, bD, indices) if 
not self.is_terminal(i)] - # if not len(non_terminals): - # self.precomputed_batches.append(None) - # self.precomputed_indices.append(None) - # continue - # bs, bD, indices = zip(*non_terminals) batch = self.ctx.collate(bD).to(dev) self.precomputed_batches.append(batch) - - # with torch.no_grad(): - # cat, *_, mo = self.trial.model(batch, ones[:len(bs)]) actions = [[] for i in range(len(bs))] offset = 0 for u, i in enumerate(ctx.action_type_order): @@ -752,11 +737,11 @@ def get_bck_trajectory_test_split(self, r, seed=142857): while len(test_set) < n: i0 = np.random.randint(len(self.states)) s0 = self.states[i0] - if len(s0.nodes) < 7: # TODO: unhardcode this + if len(s0.nodes) < 7: # TODO: unhardcode this? continue s = s0 idx = i0 - while len(s.nodes) > 5: # TODO: unhardcode this + while len(s.nodes) > 5: # TODO: unhardcode this? test_set.add(idx) actions = [ (u, a.item(), b.item()) @@ -886,35 +871,6 @@ def build_validation_data_loader(self) -> DataLoader: def main(): # Launch a test job - hps = { - "num_training_steps": 20000, - "validate_every": 100, - "num_workers": 16, - "log_dir": "./logs/basic_graphs/run_6n_19", - "model": {"num_layers": 2, "num_emb": 256}, - # WARNING: SubTB is detached targets! - "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2, "momentum": 0.995, "lr_decay": 1e10}, - # "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2, "momentum": 0.0, "lr_decay": 1e10, "opt": "RMSProp"}, - # "opt": {"opt": "SGD", "learning_rate": 0.3, "momentum": 0}, - "algo": {"global_batch_size": 4096, "tb": {"do_subtb": True}, "max_nodes": 6}, - "task": { - "basic_graph": {"do_supervised": False, "do_tabular_model": True} - }, # Change this to launch a supervised job - } - - hps = { - "num_training_steps": 20000, - "validate_every": 100, - "num_workers": 16, - "log_dir": "./logs/basic_graphs/run_6n_27", - "model": {"num_layers": 2, "num_emb": 256}, - # WARNING: SubTB is detached targets! 
-- not - "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2}, - # "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2, "momentum": 0.0, "lr_decay": 1e10, "opt": "RMSProp"}, - # "opt": {"opt": "SGD", "learning_rate": 1e-2, "momentum": 0}, - "algo": {"global_batch_size": 512, "tb": {"do_subtb": True}, "max_nodes": 6, "offline_ratio": 1 / 4}, - "task": {"basic_graph": {"do_supervised": False, "do_tabular_model": True, "train_ratio": 1}}, # - } hps = { "num_training_steps": 20000, @@ -922,13 +878,10 @@ def main(): "num_workers": 0, "log_dir": "./logs/basic_graphs/run_6n_pb2", "model": {"num_layers": 2, "num_emb": 256}, - # WARNING: SubTB is detached targets! -- not - "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2}, - # "opt": {"adam_eps": 1e-8, "learning_rate": 3e-2, "momentum": 0.0, "lr_decay": 1e10, "opt": "RMSProp"}, - # "opt": {"opt": "SGD", "learning_rate": 1e-2, "momentum": 0}, + "opt": {"adam_eps": 1e-8, "learning_rate": 3e-4}, "algo": { - "global_batch_size": 512, - "tb": {"do_subtb": True, "do_parameterize_p_b": False}, + "global_batch_size": 64, + "tb": {"variant": "SubTB1", "do_parameterize_p_b": False}, "max_nodes": 6, "offline_ratio": 0 / 4, }, From fc70f2de84b590744130ff80490404b256f092a8 Mon Sep 17 00:00:00 2001 From: Emmanuel Bengio Date: Wed, 6 Sep 2023 12:53:27 -0400 Subject: [PATCH 7/7] ruff --- src/gflownet/algo/trajectory_balance.py | 2 +- tests/test_envs.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gflownet/algo/trajectory_balance.py b/src/gflownet/algo/trajectory_balance.py index 593c7b3b..00e65522 100644 --- a/src/gflownet/algo/trajectory_balance.py +++ b/src/gflownet/algo/trajectory_balance.py @@ -190,7 +190,7 @@ def create_training_data_from_graphs(self, graphs): if hasattr(self.ctx, "relabel"): relabel = self.ctx.relabel else: - relabel = lambda *x: x + relabel = lambda *x: x # noqa: E731 trajs = [{"traj": [relabel(*t) for t in generate_forward_trajectory(i)]} for i in graphs] for traj in trajs: n_back = [ diff --git 
a/tests/test_envs.py b/tests/test_envs.py index 85c3eb38..1bc639e7 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -7,10 +7,10 @@ from gflownet.algo.trajectory_balance import TrajectoryBalance from gflownet.config import Config +from gflownet.envs.basic_graph_ctx import BasicGraphContext from gflownet.envs.frag_mol_env import FragMolBuildingEnvContext from gflownet.envs.graph_building_env import GraphBuildingEnv from gflownet.envs.mol_building_env import MolBuildingEnvContext -from gflownet.envs.basic_graph_ctx import BasicGraphContext from gflownet.models import bengio2021flow