Skip to content

Commit

Permalink
Merge pull request #253 from GFNOrg/graph-state-pyg
Browse files Browse the repository at this point in the history
switch from TensorDict to torch_geometric
  • Loading branch information
josephdviviano authored Mar 5, 2025
2 parents f5b8c98 + 50b3cda commit 4abeadd
Show file tree
Hide file tree
Showing 27 changed files with 1,451 additions and 1,212 deletions.
19 changes: 11 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
requires = ["poetry-core>=1.0.8"]
build-backend = "poetry.core.masonry.api"

[project]
name = "torchgfn"

[tool.poetry]
name = "torchgfn"
packages = [{include = "gfn", from = "src"}]
Expand All @@ -25,8 +22,9 @@ classifiers = [
einops = ">=0.6.1"
numpy = ">=1.21.2"
python = "^3.10"
tensordict = ">=0.6.1"
torch = ">=2.6.0"
tensordict = ">=0.6.1"
torch_geometric = ">=2.6.1"

# dev dependencies.
black = { version = "24.3", optional = true }
Expand All @@ -49,7 +47,6 @@ wandb = { version = "*", optional = true }
scikit-learn = {version = "*", optional = true }
scipy = { version = "*", optional = true }
matplotlib = { version = "*", optional = true }
torch_geometric = { version = ">=2.6.1", optional = true }

[tool.poetry.extras]
dev = [
Expand All @@ -66,7 +63,6 @@ dev = [
"sphinx",
"tox",
"flake8",
"torch_geometric",
]

scripts = ["tqdm", "wandb", "scikit-learn", "scipy", "matplotlib"]
Expand All @@ -89,7 +85,6 @@ all = [
"tox",
"tqdm",
"wandb",
"torch_geometric",
]

[tool.poetry.urls]
Expand All @@ -103,17 +98,22 @@ include = '\.pyi?$'
extend-exclude = '''/(\.git|\.hg|\.mypy_cache|\.ipynb|\.tox|\.venv|build)/g'''

[tool.pyright]
pythonVersion = "3.10"
include = ["src/gfn", "tutorials/examples", "testing"] # Removed ** globstars
exclude = [
"**/node_modules",
"**/__pycache__",
"**/.*", # Exclude dot files and folders
]

strict = [

]
# This is required because the CI pre-commit does not download the modules (e.g. numpy).
# Therefore, we have to ignore missing imports.
# Removed "strict": [], as it's redundant with typeCheckingMode

typeCheckingMode = "basic"
pythonVersion = "3.10"

# Removed enableTypeIgnoreComments, not available in pyproject.toml, and bad practice.

Expand All @@ -128,6 +128,9 @@ reportUntypedFunctionDecorator = "none"
reportMissingTypeStubs = false
reportUnboundVariable = "warning"
reportGeneralTypeIssues = "none"
reportAttributeAccessIssue = false

[tool.pytest.ini_options]
reportOptionalMemberAccess = "error"
reportArgumentType = "error" #This setting doesn't exist, removed.

Expand Down
32 changes: 20 additions & 12 deletions src/gfn/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,9 @@ def __init__(self, tensor: TensorDict):
Args:
action: a GraphActionType indicating the type of action.
features: a tensor of shape (batch_shape, feature_shape) representing the features of the nodes or of the edges, depending on the action type.
In case of EXIT action, this can be None.
features: a tensor of shape (batch_shape, feature_shape) representing the features
of the nodes or of the edges, depending on the action type. In case of EXIT
action, this can be None.
edge_index: a tensor of shape (batch_shape, 2) representing the edge to add.
This must be defined if and only if the action type is GraphActionType.ADD_EDGE.
"""
Expand Down Expand Up @@ -245,12 +246,16 @@ def __len__(self) -> int:
"""Returns the number of actions in the batch."""
return int(prod(self.batch_shape))

def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> GraphActions:
def __getitem__(
self, index: int | List[int] | List[bool] | slice | torch.Tensor
) -> GraphActions:
"""Get particular actions of the batch."""
return GraphActions(self.tensor[index])

def __setitem__(
self, index: int | Sequence[int] | Sequence[bool], action: GraphActions
self,
index: int | List[int] | List[bool] | slice | torch.Tensor,
action: GraphActions,
) -> None:
"""Set particular actions of the batch."""
self.tensor[index] = action.tensor
Expand All @@ -263,15 +268,18 @@ def compare(self, other: GraphActions) -> torch.Tensor:
Returns: boolean tensor of shape batch_shape indicating whether the actions are equal.
"""
compare = torch.all(self.tensor == other.tensor, dim=-1)
return (
compare["action_type"]
& (compare["action_type"] == GraphActionType.EXIT | compare["features"])
& (
compare["action_type"]
!= GraphActionType.ADD_EDGE | compare["edge_index"]
)
action_compare = torch.all(
self.tensor["action_type"] == other.tensor["action_type"]
)
exit_compare = (
torch.all(self.tensor["features"] == other.tensor["features"])
| action_compare
== GraphActionType.EXIT
)
edge_compare = (action_compare != GraphActionType.ADD_EDGE) | (
torch.all(self.tensor["edge_index"] == other.tensor["edge_index"])
)
return action_compare & exit_compare & edge_compare

@property
def is_exit(self) -> torch.Tensor:
Expand Down
30 changes: 24 additions & 6 deletions src/gfn/containers/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ class Trajectories(Container):
when_is_done: Tensor of shape (n_trajectories,) indicating the time step at which each trajectory ends.
is_backward: Whether the trajectories are backward or forward.
log_rewards: Tensor of shape (n_trajectories,) containing the log rewards of the trajectories.
log_probs: Tensor of shape (max_length, n_trajectories) indicating the log probabilities of the trajectories' actions.
log_probs: Tensor of shape (max_length, n_trajectories) indicating the log probabilities of the
trajectories' actions.
"""

Expand All @@ -57,7 +58,8 @@ def __init__(
when_is_done: Tensor of shape (n_trajectories,) indicating the time step at which each trajectory ends.
is_backward: Whether the trajectories are backward or forward.
log_rewards: Tensor of shape (n_trajectories,) containing the log rewards of the trajectories.
log_probs: Tensor of shape (max_length, n_trajectories) indicating the log probabilities of the trajectories' actions.
log_probs: Tensor of shape (max_length, n_trajectories) indicating the log probabilities of
the trajectories' actions.
estimator_outputs: Tensor of shape (batch_shape, output_dim).
When forward sampling off-policy for an n-step trajectory,
n forward passes will be made on some function approximator,
Expand Down Expand Up @@ -103,25 +105,37 @@ def __init__(
assert (
log_probs.shape == (self.max_length, self.n_trajectories)
and log_probs.dtype == torch.float
), f"log_probs.shape={log_probs.shape}, self.max_length={self.max_length}, self.n_trajectories={self.n_trajectories}"
), f"log_probs.shape={log_probs.shape}, "
f"self.max_length={self.max_length}, "
f"self.n_trajectories={self.n_trajectories}"
else:
log_probs = torch.full(size=(0, 0), fill_value=0, dtype=torch.float)
self.log_probs: torch.Tensor = log_probs

self.estimator_outputs = estimator_outputs
if self.estimator_outputs is not None:
# assert self.estimator_outputs.shape[:len(self.states.batch_shape)] == self.states.batch_shape TODO: check why fails
# TODO: check why this fails.
# assert self.estimator_outputs.shape[:len(self.states.batch_shape)] == self.states.batch_shape
assert self.estimator_outputs.dtype == torch.float

def __repr__(self) -> str:
states = self.states.tensor.transpose(0, 1)
assert states.ndim == 3
trajectories_representation = ""
assert isinstance(
self.env.s0, torch.Tensor
), "not supported for Graph trajectories."
assert isinstance(
self.env.sf, torch.Tensor
), "not supported for Graph trajectories."

for traj in states[:10]:
one_traj_repr = []
for step in traj:
one_traj_repr.append(str(step.cpu().numpy()))
if step.equal(self.env.s0 if self.is_backward else self.env.sf):
if self.is_backward and step.equal(self.env.s0):
break
elif not self.is_backward and step.equal(self.env.sf):
break
trajectories_representation += "-> ".join(one_traj_repr) + "\n"
return (
Expand Down Expand Up @@ -482,6 +496,10 @@ def reverse_backward_trajectories(self, debug: bool = False) -> Trajectories:
new_actions[torch.arange(len(self)), seq_lengths] = self.env.exit_action

# Assign reversed states to new_states
assert isinstance(states[:, -1], torch.Tensor)
assert isinstance(
self.env.s0, torch.Tensor
), "reverse_backward_trajectories not supported for Graph trajectories"
assert torch.all(states[:, -1] == self.env.s0), "Last state must be s0"
new_states[:, 0] = self.env.s0
new_states[:, 1:-1][mask] = states[:, :-1][mask][rev_idx[mask]]
Expand Down Expand Up @@ -536,7 +554,7 @@ def reverse_backward_trajectories(self, debug: bool = False) -> Trajectories:


def pad_dim0_to_target(a: torch.Tensor, target_dim0: int) -> torch.Tensor:
"""Pads tensor a to match the dimention of b."""
"""Pads tensor a along dimension 0 up to length target_dim0."""
assert a.shape[0] < target_dim0, "a is already larger than target_dim0!"
pad_dim = target_dim0 - a.shape[0]
pad_dim_full = (pad_dim,) + tuple(a.shape[1:])
Expand Down
4 changes: 2 additions & 2 deletions src/gfn/containers/transitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def __init__(
the children of the transitions.
is_backward: Whether the transitions are backward transitions (i.e.
`next_states` is the parent of states).
log_rewards: Tensor of shape (n_transitions,) containing the log-rewards of the transitions (using a default value like
`-float('inf')` for non-terminating transitions).
log_rewards: Tensor of shape (n_transitions,) containing the log-rewards of the transitions (using a
default value like `-float('inf')` for non-terminating transitions).
log_probs: Tensor of shape (n_transitions,) containing the log-probabilities of the actions.
Raises:
Expand Down
52 changes: 19 additions & 33 deletions src/gfn/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from typing import Optional, Tuple, Union, cast

import torch
from tensordict import TensorDict
from torch_geometric.data import Batch as GeometricBatch
from torch_geometric.data import Data as GeometricData

from gfn.actions import Actions, GraphActions
from gfn.preprocessors import IdentityPreprocessor, Preprocessor
Expand All @@ -23,12 +24,12 @@ class Env(ABC):

def __init__(
self,
s0: torch.Tensor | TensorDict,
s0: torch.Tensor | GeometricData,
state_shape: Tuple,
action_shape: Tuple,
dummy_action: torch.Tensor,
exit_action: torch.Tensor,
sf: Optional[torch.Tensor | TensorDict] = None,
sf: Optional[torch.Tensor | GeometricData] = None,
device_str: Optional[str] = None,
preprocessor: Optional[Preprocessor] = None,
):
Expand All @@ -51,7 +52,7 @@ def __init__(
"""
self.device = get_device(device_str, default_device=s0.device)

self.s0 = s0.to(self.device)
self.s0 = s0.to(self.device) # pyright: ignore
assert s0.shape == state_shape
if sf is None:
sf = torch.full(s0.shape, -float("inf")).to(self.device)
Expand Down Expand Up @@ -229,20 +230,7 @@ def reset(
batch_shape=batch_shape, random=random, sink=sink
)

def validate_actions(
self, states: States, actions: Actions, backward: bool = False
) -> bool:
"""First, asserts that states and actions have the same batch_shape.
Then, uses `is_action_valid`.
Returns a boolean indicating whether states/actions pairs are valid."""
assert states.batch_shape == actions.batch_shape
return self.is_action_valid(states, actions, backward)

def _step(
self,
states: States,
actions: Actions,
) -> States:
def _step(self, states: States, actions: Actions) -> States:
"""Core step function. Calls the user-defined self.step() function.
Function that takes a batch of states and actions and returns a batch of next
Expand All @@ -256,7 +244,7 @@ def _step(
valid_actions = actions[valid_states_idx]
valid_states = states[valid_states_idx]

if not self.validate_actions(valid_states, valid_actions):
if not self.is_action_valid(valid_states, valid_actions):
raise NonValidActionsError(
"Some actions are not valid in the given states. See `is_action_valid`."
)
Expand All @@ -275,19 +263,15 @@ def _step(

new_not_done_states_tensor = self.step(not_done_states, not_done_actions)

if not isinstance(new_not_done_states_tensor, (torch.Tensor, TensorDict)):
if not isinstance(new_not_done_states_tensor, (torch.Tensor, GeometricBatch)):
raise Exception(
"User implemented env.step function *must* return a torch.Tensor!"
)

new_states[~new_sink_states_idx] = self.States(new_not_done_states_tensor)
return new_states

def _backward_step(
self,
states: States,
actions: Actions,
) -> States:
def _backward_step(self, states: States, actions: Actions) -> States:
"""Core backward_step function. Calls the user-defined self.backward_step fn.
This function takes a batch of states and actions and returns a batch of next
Expand All @@ -301,7 +285,7 @@ def _backward_step(
valid_actions = actions[valid_states_idx]
valid_states = states[valid_states_idx]

if not self.validate_actions(valid_states, valid_actions, backward=True):
if not self.is_action_valid(valid_states, valid_actions, backward=True):
raise NonValidActionsError(
"Some actions are not valid in the given states. See `is_action_valid`."
)
Expand Down Expand Up @@ -590,12 +574,12 @@ def terminating_states(self) -> DiscreteStates:
class GraphEnv(Env):
"""Base class for graph-based environments."""

sf: TensorDict # this tells the type checker that sf is a TensorDict
sf: GeometricData # this tells the type checker that sf is a GeometricData

def __init__(
self,
s0: TensorDict,
sf: TensorDict,
s0: GeometricData,
sf: GeometricData,
device_str: Optional[str] = None,
preprocessor: Optional[Preprocessor] = None,
):
Expand All @@ -604,14 +588,16 @@ def __init__(
Args:
s0: The initial graph state.
sf: The sink graph state.
device_str: 'cpu' or 'cuda'. Defaults to None, in which case the device is
inferred from s0.
device_str: String representation of the device.
preprocessor: a Preprocessor object that converts raw graph states to a tensor
that can be fed into a neural network. Defaults to None, in which case
the IdentityPreprocessor is used.
"""
self.s0 = s0.to(device_str)
self.features_dim = s0["node_feature"].shape[-1]
device = get_device(device_str, default_device=s0.device)
assert s0.x is not None

self.s0 = s0.to(device) # pyright: ignore
self.features_dim = s0.x.shape[-1]
self.sf = sf

self.States = self.make_states_class()
Expand Down
2 changes: 2 additions & 0 deletions src/gfn/gym/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .box import Box
from .discrete_ebm import DiscreteEBM
from .graph_building import GraphBuilding
from .hypergrid import HyperGrid
from .line import Line

Expand All @@ -8,4 +9,5 @@
"DiscreteEBM",
"HyperGrid",
"Line",
"GraphBuilding",
]
Loading

0 comments on commit 4abeadd

Please sign in to comment.