Skip to content

Commit

Permalink
invalidate cache
Browse files Browse the repository at this point in the history
  • Loading branch information
bengioe committed Feb 26, 2024
1 parent 466ace4 commit 73c760b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/gflownet/envs/frag_mol_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,14 +190,15 @@ def graph_to_Data(self, g: Graph) -> gd.Data:
if hasattr(g, "_Data_cache") and g._Data_cache is not None:
return g._Data_cache
zeros = lambda x: np.zeros(x, dtype=np.float32) # noqa: E731
ones = lambda x: np.ones(x, dtype=np.float32) # noqa: E731
x = zeros((max(1, len(g.nodes)), self.num_node_dim))
x[0, -1] = len(g.nodes) == 0
edge_attr = zeros((len(g.edges) * 2, self.num_edge_dim))
set_edge_attr_mask = zeros((len(g.edges), self.num_edge_attr_logits))
# TODO: This is a bit silly but we have to do +1 when the graph is empty because the default
# padding action is a [0, 0, 0], which needs to be legal for the empty state. Should be
# fixable with a bit of smarts & refactoring.
remove_node_mask = zeros((x.shape[0], 1)) + (1 if len(g) == 0 else 0)
remove_node_mask = ones((x.shape[0], 1)) if len(g) == 0 else zeros((x.shape[0], 1))
remove_edge_attr_mask = zeros((len(g.edges), self.num_edge_attrs))
if len(g):
degrees = np.array(list(g.degree), dtype=np.int32)[:, 1] # type: ignore
Expand Down Expand Up @@ -266,7 +267,7 @@ def graph_to_Data(self, g: Graph) -> gd.Data:
else np.ones((1, 1), np.float32)
)
add_node_mask = add_node_mask * np.ones((x.shape[0], self.num_new_node_values), np.float32)
stop_mask = zeros((1, 1)) if has_unfilled_attach or not len(g) else np.ones((1, 1), np.float32)
stop_mask = zeros((1, 1)) if has_unfilled_attach or not len(g) else ones((1, 1))

data = gd.Data(
**{
Expand Down
4 changes: 4 additions & 0 deletions src/gflownet/envs/graph_building_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ def bridges(self):
def relabel_nodes(self, rmap):
return nx.relabel_nodes(self, rmap)

def clear_cache(self):
self._Data_cache = None


def graph_without_edge(g, e):
gp = g.copy()
Expand Down Expand Up @@ -220,6 +223,7 @@ def step(self, g: Graph, action: GraphAction) -> Graph:
else:
raise ValueError(f"Unknown action type {action.action}", action.action)

gp.clear_cache() # Invalidate cached properties since we've modified the graph
return gp

def parents(self, g: Graph):
Expand Down

0 comments on commit 73c760b

Please sign in to comment.