Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

switch from TensorDict to torch_geometric #253

Merged
merged 34 commits into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
0a1f8e1
switch from torch_geoemtric to TensorDict
younik Feb 26, 2025
382be45
update train_graph_ring
younik Feb 27, 2025
939bdab
fix all problems
younik Feb 27, 2025
0f57485
remove node_index
younik Feb 27, 2025
911157c
move dep
younik Feb 27, 2025
313562e
Merge branch 'graph-states' into graph-states
younik Feb 27, 2025
8806cda
fix pyproject.toml
hyeok9855 Feb 28, 2025
ce287eb
change state type hinting from Tensordict to torch_geometric Data
hyeok9855 Feb 28, 2025
7607945
add settings that achieve 95% in the ring generation (directed) with …
hyeok9855 Feb 28, 2025
e918ac7
rename test_state to test_graph_states
hyeok9855 Mar 3, 2025
fb4445c
renamed Batch to GeometricBatch, Data to GeometricData
josephdviviano Mar 3, 2025
461fa2d
renamed Batch to GeometricBatch, Data to GeometricData
josephdviviano Mar 3, 2025
0af5747
docstring
josephdviviano Mar 4, 2025
e717d04
fixed imports
josephdviviano Mar 4, 2025
be148ae
merge
josephdviviano Mar 4, 2025
3372702
Merge pull request #250 from younik/graph-states
hyeok9855 Mar 4, 2025
3a3e41d
Merge branch 'graph-state-pyg' of github.com:GFNOrg/torchgfn into for…
josephdviviano Mar 4, 2025
8e496d8
batch_shape as tuple for all Graph things
hyeok9855 Mar 4, 2025
2b076e0
remove since it's redundant with
hyeok9855 Mar 4, 2025
eb7c9ca
apply black
hyeok9855 Mar 4, 2025
51518ab
apply black
hyeok9855 Mar 4, 2025
f90fa7f
resolve flake-8 issues
hyeok9855 Mar 4, 2025
67349b6
fix ring exp
hyeok9855 Mar 4, 2025
2e9b035
merge conflicts resolves
josephdviviano Mar 4, 2025
96394d7
black
josephdviviano Mar 4, 2025
e1fad74
black pyright and formatting
josephdviviano Mar 4, 2025
5105837
black pyright and formatting
josephdviviano Mar 4, 2025
93ebcc1
removed lines
josephdviviano Mar 4, 2025
7da0ab5
saved changes
josephdviviano Mar 4, 2025
310e8d7
massive merge of pyright / black / testing errors -- one outstanding …
josephdviviano Mar 4, 2025
06f87d2
flake
josephdviviano Mar 5, 2025
b8a1aec
isort
josephdviviano Mar 5, 2025
efcb28b
isort
josephdviviano Mar 5, 2025
50b3cda
black
josephdviviano Mar 5, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
requires = ["poetry-core>=1.0.8"]
build-backend = "poetry.core.masonry.api"

[project]
name = "torchgfn"

[tool.poetry]
name = "torchgfn"
packages = [{include = "gfn", from = "src"}]
Expand All @@ -25,8 +22,9 @@ classifiers = [
einops = ">=0.6.1"
numpy = ">=1.21.2"
python = "^3.10"
tensordict = ">=0.6.1"
torch = ">=2.6.0"
tensordict = ">=0.6.1"
torch_geometric = ">=2.6.1"

# dev dependencies.
black = { version = "24.3", optional = true }
Expand All @@ -49,7 +47,6 @@ wandb = { version = "*", optional = true }
scikit-learn = {version = "*", optional = true }
scipy = { version = "*", optional = true }
matplotlib = { version = "*", optional = true }
torch_geometric = { version = ">=2.6.1", optional = true }

[tool.poetry.extras]
dev = [
Expand All @@ -66,7 +63,6 @@ dev = [
"sphinx",
"tox",
"flake8",
"torch_geometric",
]

scripts = ["tqdm", "wandb", "scikit-learn", "scipy", "matplotlib"]
Expand All @@ -89,7 +85,6 @@ all = [
"tox",
"tqdm",
"wandb",
"torch_geometric",
]

[tool.poetry.urls]
Expand All @@ -103,17 +98,22 @@ include = '\.pyi?$'
extend-exclude = '''/(\.git|\.hg|\.mypy_cache|\.ipynb|\.tox|\.venv|build)/g'''

[tool.pyright]
pythonVersion = "3.10"
include = ["src/gfn", "tutorials/examples", "testing"] # Removed ** globstars
exclude = [
"**/node_modules",
"**/__pycache__",
"**/.*", # Exclude dot files and folders
]

strict = [

]
# This is required as the CI pre-commit does not dl the module (i.e. numpy)
# Therefore, we have to ignore missing imports
# Removed "strict": [], as it's redundant with typeCheckingMode

typeCheckingMode = "basic"
pythonVersion = "3.10"

# Removed enableTypeIgnoreComments, not available in pyproject.toml, and bad practice.

Expand All @@ -128,6 +128,9 @@ reportUntypedFunctionDecorator = "none"
reportMissingTypeStubs = false
reportUnboundVariable = "warning"
reportGeneralTypeIssues = "none"
reportAttributeAccessIssue = false

[tool.pytest.ini_options]
reportOptionalMemberAccess = "error"
reportArgumentType = "error" #This setting doesn't exist, removed.

Expand Down
32 changes: 20 additions & 12 deletions src/gfn/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,9 @@ def __init__(self, tensor: TensorDict):

Args:
action: a GraphActionType indicating the type of action.
features: a tensor of shape (batch_shape, feature_shape) representing the features of the nodes or of the edges, depending on the action type.
In case of EXIT action, this can be None.
features: a tensor of shape (batch_shape, feature_shape) representing the features
of the nodes or of the edges, depending on the action type. In case of EXIT
action, this can be None.
edge_index: an tensor of shape (batch_shape, 2) representing the edge to add.
This must defined if and only if the action type is GraphActionType.AddEdge.
"""
Expand Down Expand Up @@ -245,12 +246,16 @@ def __len__(self) -> int:
"""Returns the number of actions in the batch."""
return int(prod(self.batch_shape))

def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> GraphActions:
def __getitem__(
self, index: int | List[int] | List[bool] | slice | torch.Tensor
) -> GraphActions:
"""Get particular actions of the batch."""
return GraphActions(self.tensor[index])

def __setitem__(
self, index: int | Sequence[int] | Sequence[bool], action: GraphActions
self,
index: int | List[int] | List[bool] | slice | torch.Tensor,
action: GraphActions,
) -> None:
"""Set particular actions of the batch."""
self.tensor[index] = action.tensor
Expand All @@ -263,15 +268,18 @@ def compare(self, other: GraphActions) -> torch.Tensor:

Returns: boolean tensor of shape batch_shape indicating whether the actions are equal.
"""
compare = torch.all(self.tensor == other.tensor, dim=-1)
return (
compare["action_type"]
& (compare["action_type"] == GraphActionType.EXIT | compare["features"])
& (
compare["action_type"]
!= GraphActionType.ADD_EDGE | compare["edge_index"]
)
action_compare = torch.all(
self.tensor["action_type"] == other.tensor["action_type"]
)
exit_compare = (
torch.all(self.tensor["features"] == other.tensor["features"])
| action_compare
== GraphActionType.EXIT
)
edge_compare = (action_compare != GraphActionType.ADD_EDGE) | (
torch.all(self.tensor["edge_index"] == other.tensor["edge_index"])
)
return action_compare & exit_compare & edge_compare

@property
def is_exit(self) -> torch.Tensor:
Expand Down
30 changes: 24 additions & 6 deletions src/gfn/containers/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ class Trajectories(Container):
when_is_done: Tensor of shape (n_trajectories,) indicating the time step at which each trajectory ends.
is_backward: Whether the trajectories are backward or forward.
log_rewards: Tensor of shape (n_trajectories,) containing the log rewards of the trajectories.
log_probs: Tensor of shape (max_length, n_trajectories) indicating the log probabilities of the trajectories' actions.
log_probs: Tensor of shape (max_length, n_trajectories) indicating the log probabilities of the
trajectories' actions.

"""

Expand All @@ -57,7 +58,8 @@ def __init__(
when_is_done: Tensor of shape (n_trajectories,) indicating the time step at which each trajectory ends.
is_backward: Whether the trajectories are backward or forward.
log_rewards: Tensor of shape (n_trajectories,) containing the log rewards of the trajectories.
log_probs: Tensor of shape (max_length, n_trajectories) indicating the log probabilities of the trajectories' actions.
log_probs: Tensor of shape (max_length, n_trajectories) indicating the log probabilities of
the trajectories' actions.
estimator_outputs: Tensor of shape (batch_shape, output_dim).
When forward sampling off-policy for an n-step trajectory,
n forward passes will be made on some function approximator,
Expand Down Expand Up @@ -103,25 +105,37 @@ def __init__(
assert (
log_probs.shape == (self.max_length, self.n_trajectories)
and log_probs.dtype == torch.float
), f"log_probs.shape={log_probs.shape}, self.max_length={self.max_length}, self.n_trajectories={self.n_trajectories}"
), f"log_probs.shape={log_probs.shape}, "
f"self.max_length={self.max_length}, "
f"self.n_trajectories={self.n_trajectories}"
else:
log_probs = torch.full(size=(0, 0), fill_value=0, dtype=torch.float)
self.log_probs: torch.Tensor = log_probs

self.estimator_outputs = estimator_outputs
if self.estimator_outputs is not None:
# assert self.estimator_outputs.shape[:len(self.states.batch_shape)] == self.states.batch_shape TODO: check why fails
# TODO: check why this fails.
# assert self.estimator_outputs.shape[:len(self.states.batch_shape)] == self.states.batch_shape
assert self.estimator_outputs.dtype == torch.float

def __repr__(self) -> str:
states = self.states.tensor.transpose(0, 1)
assert states.ndim == 3
trajectories_representation = ""
assert isinstance(
self.env.s0, torch.Tensor
), "not supported for Graph trajectories."
assert isinstance(
self.env.sf, torch.Tensor
), "not supported for Graph trajectories."

for traj in states[:10]:
one_traj_repr = []
for step in traj:
one_traj_repr.append(str(step.cpu().numpy()))
if step.equal(self.env.s0 if self.is_backward else self.env.sf):
if self.is_backward and step.equal(self.env.s0):
break
elif not self.is_backward and step.equal(self.env.sf):
break
trajectories_representation += "-> ".join(one_traj_repr) + "\n"
return (
Expand Down Expand Up @@ -482,6 +496,10 @@ def reverse_backward_trajectories(self, debug: bool = False) -> Trajectories:
new_actions[torch.arange(len(self)), seq_lengths] = self.env.exit_action

# Assign reversed states to new_states
assert isinstance(states[:, -1], torch.Tensor)
assert isinstance(
self.env.s0, torch.Tensor
), "reverse_backward_trajectories not supported for Graph trajectories"
assert torch.all(states[:, -1] == self.env.s0), "Last state must be s0"
new_states[:, 0] = self.env.s0
new_states[:, 1:-1][mask] = states[:, :-1][mask][rev_idx[mask]]
Expand Down Expand Up @@ -536,7 +554,7 @@ def reverse_backward_trajectories(self, debug: bool = False) -> Trajectories:


def pad_dim0_to_target(a: torch.Tensor, target_dim0: int) -> torch.Tensor:
"""Pads tensor a to match the dimention of b."""
"""Pads tensor a to match the dimension of b."""
assert a.shape[0] < target_dim0, "a is already larger than target_dim0!"
pad_dim = target_dim0 - a.shape[0]
pad_dim_full = (pad_dim,) + tuple(a.shape[1:])
Expand Down
4 changes: 2 additions & 2 deletions src/gfn/containers/transitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def __init__(
the children of the transitions.
is_backward: Whether the transitions are backward transitions (i.e.
`next_states` is the parent of states).
log_rewards: Tensor of shape (n_transitions,) containing the log-rewards of the transitions (using a default value like
`-float('inf')` for non-terminating transitions).
log_rewards: Tensor of shape (n_transitions,) containing the log-rewards of the transitions (using a
default value like `-float('inf')` for non-terminating transitions).
log_probs: Tensor of shape (n_transitions,) containing the log-probabilities of the actions.

Raises:
Expand Down
52 changes: 19 additions & 33 deletions src/gfn/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from typing import Optional, Tuple, Union, cast

import torch
from tensordict import TensorDict
from torch_geometric.data import Batch as GeometricBatch
from torch_geometric.data import Data as GeometricData

from gfn.actions import Actions, GraphActions
from gfn.preprocessors import IdentityPreprocessor, Preprocessor
Expand All @@ -23,12 +24,12 @@ class Env(ABC):

def __init__(
self,
s0: torch.Tensor | TensorDict,
s0: torch.Tensor | GeometricData,
state_shape: Tuple,
action_shape: Tuple,
dummy_action: torch.Tensor,
exit_action: torch.Tensor,
sf: Optional[torch.Tensor | TensorDict] = None,
sf: Optional[torch.Tensor | GeometricData] = None,
device_str: Optional[str] = None,
preprocessor: Optional[Preprocessor] = None,
):
Expand All @@ -51,7 +52,7 @@ def __init__(
"""
self.device = get_device(device_str, default_device=s0.device)

self.s0 = s0.to(self.device)
self.s0 = s0.to(self.device) # pyright: ignore
assert s0.shape == state_shape
if sf is None:
sf = torch.full(s0.shape, -float("inf")).to(self.device)
Expand Down Expand Up @@ -229,20 +230,7 @@ def reset(
batch_shape=batch_shape, random=random, sink=sink
)

def validate_actions(
self, states: States, actions: Actions, backward: bool = False
) -> bool:
"""First, asserts that states and actions have the same batch_shape.
Then, uses `is_action_valid`.
Returns a boolean indicating whether states/actions pairs are valid."""
assert states.batch_shape == actions.batch_shape
return self.is_action_valid(states, actions, backward)

def _step(
self,
states: States,
actions: Actions,
) -> States:
def _step(self, states: States, actions: Actions) -> States:
"""Core step function. Calls the user-defined self.step() function.

Function that takes a batch of states and actions and returns a batch of next
Expand All @@ -256,7 +244,7 @@ def _step(
valid_actions = actions[valid_states_idx]
valid_states = states[valid_states_idx]

if not self.validate_actions(valid_states, valid_actions):
if not self.is_action_valid(valid_states, valid_actions):
raise NonValidActionsError(
"Some actions are not valid in the given states. See `is_action_valid`."
)
Expand All @@ -275,19 +263,15 @@ def _step(

new_not_done_states_tensor = self.step(not_done_states, not_done_actions)

if not isinstance(new_not_done_states_tensor, (torch.Tensor, TensorDict)):
if not isinstance(new_not_done_states_tensor, (torch.Tensor, GeometricBatch)):
raise Exception(
"User implemented env.step function *must* return a torch.Tensor!"
)

new_states[~new_sink_states_idx] = self.States(new_not_done_states_tensor)
return new_states

def _backward_step(
self,
states: States,
actions: Actions,
) -> States:
def _backward_step(self, states: States, actions: Actions) -> States:
"""Core backward_step function. Calls the user-defined self.backward_step fn.

This function takes a batch of states and actions and returns a batch of next
Expand All @@ -301,7 +285,7 @@ def _backward_step(
valid_actions = actions[valid_states_idx]
valid_states = states[valid_states_idx]

if not self.validate_actions(valid_states, valid_actions, backward=True):
if not self.is_action_valid(valid_states, valid_actions, backward=True):
raise NonValidActionsError(
"Some actions are not valid in the given states. See `is_action_valid`."
)
Expand Down Expand Up @@ -590,12 +574,12 @@ def terminating_states(self) -> DiscreteStates:
class GraphEnv(Env):
"""Base class for graph-based environments."""

sf: TensorDict # this tells the type checker that sf is a TensorDict
sf: GeometricData # this tells the type checker that sf is a GeometricData

def __init__(
self,
s0: TensorDict,
sf: TensorDict,
s0: GeometricData,
sf: GeometricData,
device_str: Optional[str] = None,
preprocessor: Optional[Preprocessor] = None,
):
Expand All @@ -604,14 +588,16 @@ def __init__(
Args:
s0: The initial graph state.
sf: The sink graph state.
device_str: 'cpu' or 'cuda'. Defaults to None, in which case the device is
inferred from s0.
device_str: String representation of the device.
preprocessor: a Preprocessor object that converts raw graph states to a tensor
that can be fed into a neural network. Defaults to None, in which case
the IdentityPreprocessor is used.
"""
self.s0 = s0.to(device_str)
self.features_dim = s0["node_feature"].shape[-1]
device = get_device(device_str, default_device=s0.device)
assert s0.x is not None

self.s0 = s0.to(device) # pyright: ignore
self.features_dim = s0.x.shape[-1]
self.sf = sf

self.States = self.make_states_class()
Expand Down
2 changes: 2 additions & 0 deletions src/gfn/gym/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .box import Box
from .discrete_ebm import DiscreteEBM
from .graph_building import GraphBuilding
from .hypergrid import HyperGrid
from .line import Line

Expand All @@ -8,4 +9,5 @@
"DiscreteEBM",
"HyperGrid",
"Line",
"GraphBuilding",
]
Loading